Commit 5769ee4
Parent(s):
Removed history to avoid any unverified information being released

This view is limited to 50 files because it contains too many changes.
- .gitattributes +31 -0
- .gitignore +14 -0
- README.md +93 -0
- app.py +3 -0
- export_waymo_to_json.py +94 -0
- image/illustration.png +0 -0
- import_dataset_from_huggingface.py +55 -0
- import_model_from_huggingface.py +16 -0
- notebooks/visualize_planner_evaluation_results.ipynb +344 -0
- requirements.txt +16 -0
- risk_biased/__init__.py +0 -0
- risk_biased/config/learning_config.py +156 -0
- risk_biased/config/paths.py +9 -0
- risk_biased/config/planning_config.py +13 -0
- risk_biased/config/waymo_config.py +104 -0
- risk_biased/models/__init__.py +0 -0
- risk_biased/models/biased_cvae_model.py +907 -0
- risk_biased/models/context_gating.py +53 -0
- risk_biased/models/cvae_decoder.py +388 -0
- risk_biased/models/cvae_encoders.py +376 -0
- risk_biased/models/cvae_params.py +78 -0
- risk_biased/models/latent_distributions.py +468 -0
- risk_biased/models/map_encoder.py +38 -0
- risk_biased/models/mlp.py +60 -0
- risk_biased/models/multi_head_attention.py +81 -0
- risk_biased/models/nn_blocks.py +626 -0
- risk_biased/mpc_planner/__init__.py +0 -0
- risk_biased/mpc_planner/dynamics.py +49 -0
- risk_biased/mpc_planner/planner.py +332 -0
- risk_biased/mpc_planner/planner_cost.py +127 -0
- risk_biased/mpc_planner/solver.py +429 -0
- risk_biased/predictors/biased_predictor.py +568 -0
- risk_biased/scene_dataset/__init__.py +0 -0
- risk_biased/scene_dataset/loaders.py +252 -0
- risk_biased/scene_dataset/pedestrian.py +165 -0
- risk_biased/scene_dataset/scene.py +522 -0
- risk_biased/scene_dataset/scene_plotter.py +276 -0
- risk_biased/utils/__init__.py +0 -0
- risk_biased/utils/callbacks.py +595 -0
- risk_biased/utils/config_argparse.py +96 -0
- risk_biased/utils/cost.py +539 -0
- risk_biased/utils/load_model.py +220 -0
- risk_biased/utils/loss.py +124 -0
- risk_biased/utils/metrics.py +81 -0
- risk_biased/utils/planner_utils.py +462 -0
- risk_biased/utils/risk.py +571 -0
- risk_biased/utils/torch_utils.py +66 -0
- risk_biased/utils/waymo_dataloader.py +490 -0
- scripts/eval_scripts/compute_stats.py +523 -0
- scripts/eval_scripts/draw_cost.py +84 -0
.gitattributes
ADDED
@@ -0,0 +1,31 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,14 @@
+__pycache__/*
+*/__pycache__/*
+**/__pycache__/*
+scene_data_*.npy
+scripts/logs/*
+logs/*
+wandb/*
+data/*
+
+*~*
+*#*
+*sweep_logs*
+*.ipynb_checkpoints*
+*.egg-info*
README.md
ADDED
@@ -0,0 +1,93 @@
+---
+title: "RAP: Risk-Aware Prediction"
+emoji: 🚙
+colorFrom: red
+colorTo: grey
+sdk: gradio
+sdk_version: 3.7
+app_file: app.py
+pinned: false
+language:
+- Python
+thumbnail: "url to a thumbnail used in social sharing"
+tags:
+- Risk Measures
+- Forecasting
+- Safety
+- Human-Robot Interaction
+license: cc-by-nc-4.0
+
+---
+
+# License statement
+
+The code is provided under an Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license. Under this license, the code is provided royalty-free for non-commercial purposes only. The code may be covered by patents; if you want to use the code for commercial purposes, please contact us for a different license.
+
+# RAP: Risk-Aware Prediction
+
+This is the official code for [RAP: Risk-Aware Prediction for Robust Planning](https://arxiv.org/abs/2210.01368). You can test the results in [our huggingface demo](https://huggingface.co/spaces/TRI-ML/risk_biased_prediction) and see some additional experiments on the [paper website](https://sites.google.com/view/corl-risk/).
+
+
+
+We define and train a trajectory forecasting model and bias its predictions towards risk, so that it helps a planner estimate risk by producing the relevant pessimistic trajectory forecasts to consider.
+
+## Datasets
+This repository uses two datasets:
+- A didactic simulated environment with a single vehicle at constant velocity and a single pedestrian.
+  Two pedestrian behaviors are implemented: fast and slow. At each step, a pedestrian might walk at its favored speed or at the other speed.
+  This produces a distribution of pedestrian trajectories with two modes (a minimal illustrative sketch follows below). The dataset is automatically generated and used. You can change the parameters of the data generation in "config/learning_config.py".
+- The Waymo Open Motion Dataset (WOMD) with complex real scenes.
+
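To make the two-mode behavior concrete, here is a minimal illustrative sketch of such a generator. It is not the repository's implementation (that lives in risk_biased/scene_dataset/pedestrian.py, with parameters in the config); the function name, probability, and timestep below are invented for illustration.

import numpy as np

def sample_pedestrian_speeds(n_steps, favored_speed, other_speed, p_favored=0.8, rng=None):
    # At each step the pedestrian walks at its favored speed with probability
    # p_favored, otherwise at the other speed: a two-mode trajectory distribution.
    rng = rng or np.random.default_rng()
    walks_favored = rng.random(n_steps) < p_favored
    return np.where(walks_favored, favored_speed, other_speed)

# A "fast" pedestrian favoring 2.0 m/s, occasionally dropping to 1.0 m/s (values assumed).
speeds = sample_pedestrian_speeds(n_steps=50, favored_speed=2.0, other_speed=1.0)
positions = np.cumsum(speeds) * 0.1  # assumed 0.1 s timestep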
+## Forecasting model
+A conditional variational auto-encoder (CVAE) is used as the base pedestrian trajectory predictor. Its latent space is either quantized or Gaussian, depending on the parameter that you set in the config. It uses either multi-head attention or a modified version of context gating to account for interactions. Depending on the parameters, the trajectory encoder and decoder can be set to MLP, LSTM, or masked LSTM.
+
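As a rough mental model of the structure described above, a minimal Gaussian-latent CVAE looks like the sketch below. All class and argument names are invented; the actual model, with the quantized-latent option, attention, and LSTM variants, is in risk_biased/models/biased_cvae_model.py and the other risk_biased/models files.

import torch
import torch.nn as nn

class TinyCVAE(nn.Module):
    # Minimal CVAE: a posterior encoder over (past, future), a prior over past only,
    # and a decoder that maps (past, latent sample) to a future trajectory.
    def __init__(self, past_dim, future_dim, latent_dim=16, hidden_dim=64):
        super().__init__()
        self.prior_net = nn.Sequential(
            nn.Linear(past_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 2 * latent_dim))
        self.posterior_net = nn.Sequential(
            nn.Linear(past_dim + future_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 2 * latent_dim))
        self.decoder = nn.Sequential(
            nn.Linear(past_dim + latent_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, future_dim))

    def forward(self, past, future=None):
        # Use the posterior at training time (future available), the prior at test time.
        if future is None:
            mean, log_std = self.prior_net(past).chunk(2, dim=-1)
        else:
            mean, log_std = self.posterior_net(torch.cat([past, future], dim=-1)).chunk(2, dim=-1)
        z = mean + log_std.exp() * torch.randn_like(mean)  # reparameterization trick
        return self.decoder(torch.cat([past, z], dim=-1)), mean, log_std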
+# Usage
+
+## Installation
+
+- (Set up a virtual environment with python>3.7)
+- Install the package with `pip install -e .`
+
+## Setting up the data
+
+### Didactic simulation
+- The dataset is automatically generated and used. You can change the parameters of the data generation in "config/learning_config.py".
+
+### WOMD
+- [Download the Waymo Open Motion Dataset (WOMD)](https://waymo.com/open/)
+- Pre-process it as follows:
+  - Sample set: `python scripts/scripts_utils/generate_dataset_waymo.py <data/Waymo>/scenario/validation <data/Waymo>/interactive_veh_type/sample --num_parallel=<16> --debug_size=<1000>`
+  - Training set: `python scripts/interaction_utils/generate_dataset_waymo.py <data/Waymo>/scenario/training <data/Waymo>/interactive_veh_type/training --num_parallel=<16>`
+  - Validation set: `python scripts/interaction_utils/generate_dataset_waymo.py <data/Waymo>/scenario/validation_interactive <data/Waymo>/interactive_veh_type/validation --num_parallel=<16>`
+
+  Replace the arguments:
+  - `<data/Waymo>` with the path where you downloaded WOMD
+  - `<16>` with the number of cores you want to use
+  - `<1000>` with the number of scenes to process for the sample set (some scenes are filtered out, so the resulting number of pre-processed scenes might be about a third of the input number)
+- Set up the path to the dataset in "risk_biased/config/paths.py"
+
+## Configuration and training
+
+- Set up the output log path in "risk_biased/config/paths.py"
+- You might need to log in to wandb with `wandb login <option> <key>...`
+- All the parameters defined in "risk_biased/config/learning_config.py" or "risk_biased/config/waymo_config.py" can be overwritten with a command line argument (see the example below).
+- To start from a WandB checkpoint, use the option `--load_from "<wandb id>"`. If you wish to force the usage of the local configuration instead of the checkpoint configuration, add the option `--force_config`. If you want to load the last checkpoint instead of the best one, add the option `--load_last`.
+
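For example, a run resuming from a checkpoint while keeping the local configuration could look like the line below. The `--load_from` and `--force_config` options are quoted from the list above; `--batch_size` is only an assumed example of a config parameter being overridden.

`python scripts/train_interaction.py --load_from "<wandb id>" --force_config --batch_size 64`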
+### Didactic simulation
+- Choose the parameters to set in "risk_biased/config/learning_config.py"
+- Start training: `python scripts/train_didactic.py`
+
+### WOMD
+- Choose the parameters to set in "risk_biased/config/waymo_config.py"
+- Start training: `python scripts/train_interaction.py`
+
+Training has two phases: training the unbiased predictor, then training the biased encoder with a frozen predictor model. This second step needs to draw many samples to estimate the risk, so it is possible that your GPU runs out of memory at this stage. If it does, consider reducing the batch size and the number of samples "n_mc_samples_biased". If the number of samples "n_mc_samples_risk" is kept high, the risk estimation will be more accurate but training might be very slow.
+
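The accuracy/speed trade-off here is the usual Monte Carlo one. As a generic illustration only (the repository's actual risk measures are in risk_biased/utils/risk.py), a sample-based conditional value-at-risk estimate might look like:

import torch

def cvar(costs: torch.Tensor, risk_level: float) -> torch.Tensor:
    # Mean of the worst (1 - risk_level) fraction of sampled costs, along the last dim.
    n_samples = costs.shape[-1]
    k = max(1, int(n_samples * (1.0 - risk_level)))
    worst_costs, _ = torch.topk(costs, k, dim=-1)
    return worst_costs.mean(dim=-1)

costs = torch.rand(4, 256)           # e.g. 256 Monte Carlo cost samples per batch element
print(cvar(costs, risk_level=0.95))  # fewer samples make this tail estimate noisier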
+## Evaluation
+
+Many evaluation scripts are available in "scripts/eval_scripts" to compute results, plot graphs, draw the didactic experiment scene, etc.
+
+You can also run the interactive interface locally with `python scripts/scripts_utils/plotly_interface.py --load_from=<full path to the .ckpt checkpoint file> --cfg_path=<full path to the learning_config.py file from the checkpoint>`
+Sadly, the WOMD license does not allow us to provide the pre-trained weights of our model, so you will need to train it yourself.
app.py
ADDED
@@ -0,0 +1,3 @@
+from scripts.scripts_utils.plotly_interface import main
+
+main()
export_waymo_to_json.py
ADDED
@@ -0,0 +1,94 @@
+import json
+from json import JSONEncoder
+from mmcv import Config
+import numpy
+import torch
+
+from risk_biased.utils.waymo_dataloader import WaymoDataloaders
+
+
+class NumpyArrayEncoder(JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, numpy.ndarray):
+            return obj.tolist()
+        return JSONEncoder.default(self, obj)
+
+
+if __name__ == "__main__":
+    output_path = "../risk_biased_dataset/data.json"
+    config_path = "risk_biased/config/waymo_config.py"
+    cfg = Config.fromfile(config_path)
+    dataloaders = WaymoDataloaders(cfg)
+    sample_dataloader = dataloaders.sample_dataloader()
+    (
+        x,
+        mask_x,
+        y,
+        mask_y,
+        mask_loss,
+        map_data,
+        mask_map,
+        offset,
+        x_ego,
+        y_ego,
+    ) = sample_dataloader.collate_fn(sample_dataloader.dataset)
+
+    batch_size, n_agents, n_timesteps_past, n_features = x.shape
+    n_timesteps_future = y.shape[2]
+    n_features_map = map_data.shape[3]
+    n_features_offset = offset.shape[2]
+
+    print(x.shape)
+    print(mask_x.shape)
+    print(y.shape)
+    print(mask_y.shape)
+    print(mask_loss.shape)
+    print(map_data.shape)
+    print(mask_map.shape)
+    print(offset.shape)
+    print(x_ego.shape)
+    print(y_ego.shape)
+
+    data = {
+        "x": x.numpy(),
+        "mask_x": mask_x.numpy(),
+        "y": y.numpy(),
+        "mask_y": mask_y.numpy(),
+        "mask_loss": mask_loss.numpy(),
+        "map_data": map_data.numpy(),
+        "mask_map": mask_map.numpy(),
+        "offset": offset.numpy(),
+        "x_ego": x_ego.numpy(),
+        "y_ego": y_ego.numpy(),
+    }
+
+    json_data = json.dumps(data, cls=NumpyArrayEncoder)
+
+    with open(output_path, "w+") as f:
+        f.write(json_data)
+
+    # Round-trip check: reload the JSON, cast back to the original dtypes,
+    # and make sure the tensors are unchanged.
+    with open(output_path, "r") as f:
+        decoded = json.load(f)
+
+    x_c = torch.from_numpy(numpy.array(decoded["x"]).astype(numpy.float32))
+    mask_x_c = torch.from_numpy(numpy.array(decoded["mask_x"]).astype(numpy.bool8))
+    y_c = torch.from_numpy(numpy.array(decoded["y"]).astype(numpy.float32))
+    mask_y_c = torch.from_numpy(numpy.array(decoded["mask_y"]).astype(numpy.bool8))
+    mask_loss_c = torch.from_numpy(numpy.array(decoded["mask_loss"]).astype(numpy.bool8))
+    map_data_c = torch.from_numpy(numpy.array(decoded["map_data"]).astype(numpy.float32))
+    mask_map_c = torch.from_numpy(numpy.array(decoded["mask_map"]).astype(numpy.bool8))
+    offset_c = torch.from_numpy(numpy.array(decoded["offset"]).astype(numpy.float32))
+    x_ego_c = torch.from_numpy(numpy.array(decoded["x_ego"]).astype(numpy.float32))
+    y_ego_c = torch.from_numpy(numpy.array(decoded["y_ego"]).astype(numpy.float32))
+
+    assert torch.allclose(x, x_c)
+    assert torch.allclose(mask_x, mask_x_c)
+    assert torch.allclose(y, y_c)
+    assert torch.allclose(mask_y, mask_y_c)
+    assert torch.allclose(mask_loss, mask_loss_c)
+    assert torch.allclose(map_data, map_data_c)
+    assert torch.allclose(mask_map, mask_map_c)
+    assert torch.allclose(offset, offset_c)
+    assert torch.allclose(x_ego, x_ego_c)
+    assert torch.allclose(y_ego, y_ego_c)
+
+    print("All good!")
image/illustration.png
ADDED
(binary file)
import_dataset_from_huggingface.py
ADDED
@@ -0,0 +1,55 @@
+from datasets import load_dataset
+import datasets
+import json
+from mmcv import Config
+import numpy
+import torch
+
+from risk_biased.utils.waymo_dataloader import WaymoDataloaders
+
+
+config_path = "risk_biased/config/waymo_config.py"
+cfg = Config.fromfile(config_path)
+dataloaders = WaymoDataloaders(cfg)
+sample_dataloader = dataloaders.sample_dataloader()
+(
+    x,
+    mask_x,
+    y,
+    mask_y,
+    mask_loss,
+    map_data,
+    mask_map,
+    offset,
+    x_ego,
+    y_ego,
+) = sample_dataloader.collate_fn(sample_dataloader.dataset)
+
+# dataset = load_dataset("json", data_files="../risk_biased_dataset/data.json", split="test", field="x")
+# dataset = load_from_disk("../risk_biased_dataset/data.json")
+dataset = load_dataset("jmercat/risk_biased_dataset", split="test")
+
+x_c = torch.from_numpy(numpy.array(dataset["x"]).astype(numpy.float32))
+mask_x_c = torch.from_numpy(numpy.array(dataset["mask_x"]).astype(numpy.bool8))
+y_c = torch.from_numpy(numpy.array(dataset["y"]).astype(numpy.float32))
+mask_y_c = torch.from_numpy(numpy.array(dataset["mask_y"]).astype(numpy.bool8))
+mask_loss_c = torch.from_numpy(numpy.array(dataset["mask_loss"]).astype(numpy.bool8))
+map_data_c = torch.from_numpy(numpy.array(dataset["map_data"]).astype(numpy.float32))
+mask_map_c = torch.from_numpy(numpy.array(dataset["mask_map"]).astype(numpy.bool8))
+offset_c = torch.from_numpy(numpy.array(dataset["offset"]).astype(numpy.float32))
+x_ego_c = torch.from_numpy(numpy.array(dataset["x_ego"]).astype(numpy.float32))
+y_ego_c = torch.from_numpy(numpy.array(dataset["y_ego"]).astype(numpy.float32))
+
+assert torch.allclose(x, x_c)
+assert torch.allclose(mask_x, mask_x_c)
+assert torch.allclose(y, y_c)
+assert torch.allclose(mask_y, mask_y_c)
+assert torch.allclose(mask_loss, mask_loss_c)
+assert torch.allclose(map_data, map_data_c)
+assert torch.allclose(mask_map, mask_map_c)
+assert torch.allclose(offset, offset_c)
+assert torch.allclose(x_ego, x_ego_c)
+assert torch.allclose(y_ego, y_ego_c)
+
+print("All good!")
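As a side note, recent versions of the `datasets` library can return torch tensors directly, which would replace the per-key casts above; a minimal sketch, assuming the same dataset layout:

from datasets import load_dataset
import torch

dataset = load_dataset("jmercat/risk_biased_dataset", split="test")
dataset = dataset.with_format("torch")  # columns now come back as torch tensors
x = dataset["x"].to(torch.float32)
mask_x = dataset["mask_x"].to(torch.bool)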
import_model_from_huggingface.py
ADDED
@@ -0,0 +1,16 @@
+from huggingface_hub import hf_hub_url, cached_download
+from mmcv import Config
+import torch
+
+from risk_biased.utils.load_model import get_predictor
+from risk_biased.utils.torch_utils import load_weights
+from risk_biased.utils.waymo_dataloader import WaymoDataloaders
+
+
+config_file = cached_download(hf_hub_url("jmercat/risk_biased_model", filename="learning_config.py"), force_filename="learning_config.py")
+ckpt = torch.load(cached_download(hf_hub_url("jmercat/risk_biased_model", filename="last.ckpt"), force_filename="last.ckpt"), map_location="cpu")
+cfg = Config.fromfile(config_file)
+predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
+predictor = load_weights(predictor, ckpt)
+
+print("Model loaded")
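Note that `hf_hub_url` and `cached_download` are deprecated in recent versions of `huggingface_hub`; an equivalent sketch with the current API (same repository and filenames as above) would be:

from huggingface_hub import hf_hub_download
import torch

config_file = hf_hub_download("jmercat/risk_biased_model", filename="learning_config.py")
ckpt = torch.load(hf_hub_download("jmercat/risk_biased_model", filename="last.ckpt"), map_location="cpu")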
notebooks/visualize_planner_evaluation_results.ipynb
ADDED
@@ -0,0 +1,344 @@
(Notebook JSON rendered as cells below; interactive matplotlib Javascript widgets and embedded PNG outputs are omitted.)

Cell [1] (code):
%matplotlib notebook
# Switch to inline if debugging plotting.
# %matplotlib inline

import os
import pickle
import sys

from matplotlib import animation
import matplotlib.pyplot as plt
import numpy as np

Cell (markdown):
### Load evaluation results.

Cell [2] (code):
# Replace with the directory containing evaluation outputs.
stats_dir = "../scripts/eval_scripts/logs/planner_eval/run-1rdonjl7_0/"
scene_type = "safer_slow"
risk_level = 1.0
num_samples = 256

filepath = os.path.join(stats_dir, f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_predictor.pkl")
with open(filepath, "rb") as infile:
    predictor_data = pickle.load(infile)
filepath = os.path.join(stats_dir, f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_planner.pkl")
with open(filepath, "rb") as infile:
    planner_data = pickle.load(infile)

Cell (markdown):
### Find the most relevant episodes.

Cell [3] (code):
def print_riskiest_episodes(data, max_ep_id=10000, num_to_print=10, key="interaction_risk"):
    episode_id_risk = []
    for episode_id in range(max_ep_id):
        if episode_id not in data:
            break
        episode_id_risk.append((episode_id, data[episode_id][key]))
    episode_id_risk = sorted(episode_id_risk, key=lambda x: x[1], reverse=True)

    print("Riskiest episodes:")
    for ep_id, risk in episode_id_risk[:num_to_print]:
        print(f"episode id: {ep_id}\trisk: {risk:0.4f}")

def print_largest_risk_difference_episodes(
        data_a,
        data_b,
        max_ep_id=10000,
        key="interaction_risk",
        num_to_print=10):
    """Prints the episodes where the risk of a exceeds that of b the most."""
    risk_a = []
    risk_b = []
    for episode_id in range(max_ep_id):
        if episode_id not in data_a or episode_id not in data_b:
            break
        risk_a.append(data_a[episode_id][key])
        risk_b.append(data_b[episode_id][key])
    risk_a = np.array(risk_a)
    risk_b = np.array(risk_b)

    diff = risk_a - risk_b
    indices = np.argsort(diff)[::-1]

    print("Episodes where the first data is riskier than the second")
    for episode_id in indices[:num_to_print]:
        print(f"episode_id: {episode_id}\tfirst data risk: {risk_a[episode_id]:0.4f}\tsecond data risk: {risk_b[episode_id]:0.4f}")

Cell [4] (code):
print_riskiest_episodes(predictor_data)

Output:
Riskiest episodes:
episode id: 66  risk: 1.1713
episode id: 48  risk: 1.0506
episode id: 32  risk: 1.0245
episode id: 37  risk: 1.0139
episode id: 39  risk: 0.9978
episode id: 84  risk: 0.9266
episode id: 24  risk: 0.9190
episode id: 79  risk: 0.8993
episode id: 75  risk: 0.8989
episode id: 71  risk: 0.8468

Cell [6] (code):
print_largest_risk_difference_episodes(predictor_data, planner_data)

Output:
Episodes where the first data is riskier than the second
episode_id: 54  first data risk: 1.2152  second data risk: 1.0733
episode_id: 82  first data risk: 0.5233  second data risk: 0.3918
episode_id: 30  first data risk: 1.2241  second data risk: 1.0988
episode_id: 29  first data risk: 0.8171  second data risk: 0.7078
episode_id: 48  first data risk: 1.0183  second data risk: 0.9133
episode_id: 18  first data risk: 0.4989  second data risk: 0.4065
episode_id: 95  first data risk: 0.5201  second data risk: 0.4409
episode_id: 4   first data risk: 0.8029  second data risk: 0.7406
episode_id: 23  first data risk: 0.9542  second data risk: 0.8920
episode_id: 72  first data risk: 1.0446  second data risk: 0.9828

Cell (markdown):
### Animate those episodes for the risk-sensitive predictor vs risk-sensitive planner.

Cell [7] (code):
def animate_episode(solver_infos, ado_positions, ado_predictions):
    fig = plt.figure()
    ax = plt.axes(xlim=(0, 100), ylim=(-4, 8))

    scatter = ax.scatter([], [])
    text = ax.annotate("", (2, 7.5))

    def animate(t):
        # 45 timesteps per solver iteration.
        solver_iter = t // 45
        timestep = t % 45
        ado_position = np.array([[ado_positions[0, timestep, 0], ado_positions[0, timestep, 1]]])
        solver_info = solver_infos[solver_iter]
        biased_predicted_ado_positions = solver_info["ado_state_future_samples"][:, 0, timestep, :2].reshape([-1, 2])
        predicted_ado_positions = ado_predictions[:num_samples, 0, timestep, :2].reshape([-1, 2])
        len_pred = len(predicted_ado_positions)
        len_biased_pred = len(biased_predicted_ado_positions)
        ego_positions = solver_info["ego_state_future"][:, 0, timestep, :2]
        positions = np.concatenate((ado_position, predicted_ado_positions, biased_predicted_ado_positions, ego_positions))
        scatter.set_offsets(positions)
        scatter.set_alpha(0.5)
        colors = np.ones(len(positions)) * 0.4
        colors[0] = 0
        colors[1:len_pred + 1] = 0.6
        colors[len_pred + 1:len_biased_pred + len_pred + 1] = 0.8
        scatter.set_array(colors)
        total_risk = solver_info["total_risk"].mean()
        tracking_cost = solver_info["tracking_cost"].mean()
        text_str = "solver_iter: {}, timestep: {:02d}, total risk: {:0.2f}, tracking cost: {:0.2f}".format(
            solver_iter,
            timestep,
            total_risk,
            tracking_cost,
        )
        text.set_text(text_str)
        return scatter, text

    num_frames = 45 * 10
    anim = animation.FuncAnimation(fig, animate, frames=num_frames, interval=20, blit=True, save_count=sys.maxsize)

    return anim

Cell [8] (code):
episode_id = 66

Cell [9] (code; its source is cut off in this view):
Output: interactive matplotlib figure (Javascript widget and base64 PNG data omitted; the diff view is truncated at this point).
dnnXnqhYqnvi2batL/0+Ag9T1SFN78B1OyhCuczzzwTvV7f2WKLbijbcMMNowL+la98JXq5GnIdD5599lnbddddP/EWumlNs3n67LqeOd4nG2+8sekH0ze+8Y3o32lW6JprrokKebzoZjjlKH5Nc9sXN+2F8tnU91mzkrrM5p133jFtj2bqdZmNmtV40ezVT3/6U5sxY0a0jZ/5zGeiMVryG0DlWce0P/zhD/a5z30uyn/uJSl6vX6o6BirHyXrr79+9Pfd4x8vek/tj5NPPjly1fdI/13Hp+Yur1Eutd3aL7pkSHnX8U3HMC06Diq3+m5vsMEG0aPElMd41ripY79m0fMfLfa3v/2t8Xvd3P6YNWuWbbvttjZt2rToe6JF34MDDzww8tZlTUn2Z5JZeG2nsvPmm28Wi3C7+e+l1u9288FK2FAawBKw8l9aSoD0K00FQA+e1sFGU+9PPvlk1FDooKfGTMVGjYAaQzUz+nf/+te/ogYhv7joYPP3v/89eo940YFx6tSpjf9un332if60nQ5cmknQe6vIqmDoIKX16RedDrgqFCqcai6bW3IPuGqI1Ezq4KpCp0UH+TfeeCP69zoAqjlTQ6uiqX+nu6zjA7sOPNq2gw8+OPp3KtZqiPQa/SpOuuyxxx5RYUz667Sp7S7UAKrRUxOug7Q+l4qymjT99Rg9QkjFRDcOaV+oMdJSzF1FR/tYDZ3M9QxKvZ+KvA6yMtFsmhZlQ//IV7M5+lGgA7sKpAqXiqbyoW3XfpDx5ZdfHs3kqsipAMUzR3GGkhYYrV9FUbMM2pdqEjSzox8vKmjxftPrVPDUmOrHRrFFjZFep9NV+pyasVJOlVE103pvFXr9iNH35Y9//GNU0PXfmxqnhkDZUX7VmMlQn1e2mj1LMjun5iE/k9o/uQ2gcizPG264Ifpho+3SjPSXvvSlddahfakZFK33z3/+c9TUaVGu9e/UTBVadNpNjU1+UZaFfjTkn4rXeygD+tzalrjh1b9Xk6L3in9YaQZQufjNb34TnXa/8847owZGTYtmeoot2vdN5bOQnbZFufnsZz8bZVbbqRxrxv7MM8+MVqeZzv/5n/+JvgtHHHFElK+HHnqo8YH8uQ2gsqB/9N/jyw0KNYA6tuoUpj6vsqT31ndHTbKO2dqfapLUwOlHr1x1TGyuAfzqV78aNeD6UaHvmH5kv//++6Z/P2/evOj7q32rH0FqLPU5dXzQMba5Y78MtA+0XfGxUcdm7Sd9Nn3+pn4Yq0boUgFlJV7WrFlj3bp1i5pf1ZnmllJm4XWsUXOpH60dZSmlfneUz5z/OWgAHXu2lADpegvNzungr4N5/jJo0KDogKHZhHjRQU4zH7reKr+AqbDtvPPO0b9XY6lZGP2vvqjf/va3oxk4NV/6tdq1a9fG99SBXgdfNQs6OKnx0wEsLlDFOAodcPMv0lajoMKphjNetD2akfrwww+jA5QObGqC4l/68eu23nrraKaw2MErfr2KmGY05BvPfhT7DPrvhU4tFWoA4+vc4mse1QyqaVEzGD89Xtusg79mv5K4q+HRLGKhB44XOsX29ttv22abbWb639xf9Z///OejQqh9qHGahcptROOZI/0o0Ou0n9UQqAEodr2eZrbUuKugKCe6YUrL/PnzTVlV46tZtnhRpjQLonUlWQoVH2VDMzb6LEcddVT0NjqtFe8r/ehJWrRUnJVpNR1quJM2gPmZzB930EEHRY1fodnm+LX6URbP2uoShdxnjKnp0H7QPii0aIZJ+zH3Lw7pddpf+uFR6AkFMtL3Wj9A9J3TbK1+PCqPaoLUgGpRg66GRbOR+mGgRl6Ngl6TdGnqFHBT3+f899WPNM3uxo2ETl0r2039/fV436uJ0gy0rl/M/Z4XOh7pB208W60zEvpRqh9wOi7qsgUdIzVDpuOQFv3YUMPWVAMYz7Bq3frO5S9qMPUDRDNy+hGiRddsavZTM7o6Vjd37G/qFLAmB/Rd0/Gw0KLv/S233NL4YzF+ja5h1ufNPTtUaHzS75J+WGj7te/iH/pJ81LOryulfpfz5/BsGw2gQ6+UAGnWQxdz6+YF/a8OuppNUsGL30dfSDVJ8aJf+/p1rlmBQgVMB0I1PzrQa1ZH76sDpU7rqGnUqUvNrOQumq3QTKEKiYrU7373u+g0bNIlSQOopvWll15a5zSPDsS6qUC/xDU7ogO7Dia5p06TbkP8On1mXfenWQkdLEtZkjaAKpCahY0XNW1qLjRzES/aZzpA6pRwUnedctc4FZTDDjusceamUIHVuvQ549Pn8XrVJOj0rK7X0Th56t/l/lkj5UvX/DT1122aMtMMh2bS1FAqXzrdqxmttmwAlRnNruT/SNKPAX0ONV1NFS1lWLN+akC1f/SDSA2l7DTbk7QBzM9k/rjJkydHM5ea8dF3WLOgamK0xK9V86x/9L3N//4Vy2hLGkC95wsvvNA4m6cfX8qVcqDvnbZZi2andPxR46AmVpdeqFFUw/qpT32q2KZF/725BrDQ91nZ1Iy0fjApT/pBoR9A+mGqRU2ovjNqegst+p7q2Kl9qaZRzWLuUuh4pB/S+rEQL8qU9pnykXtMjf97nLumGkD9yNQPEh07c09dx+P1HVSTH8/g6d/ruK1T05phVBPX1LFfr23pNYBpNID6saLjlJzVKHekpZT63ZE+d+5noQF07NlSA6SDcXw9kGa+dFeeCpYaNh1ASm0AdQDQr2kdwHTw1Z2AOl2mRQ2eZpkKnWrS6R8VgPgaN/1CTbokaQDV4Ok0qBrQ/EWzlDq94b24W7NNmuFU05V7fU/Sz5G0AVSRzPUpdLDONUnirm3UrIKaExVnfRbtRzU6hQqsiqiunVIjquKeu+j0sGY4WrsBzF2HTuNpRkUzn61xCljvXaiR8zSAmoXVzLpmIjVLqgZQM3/6nqlJS9oA5s9mFxqnmVGdhtSMkGZ+1HBodiR+rfKoGS1d26bTn6UsLTkFnPv+mnHSPtLsp2YEdW2YGiw1YPEF/bkzaGoU9e81M5Zkaa4BzLfTKVPNxmk2Sg2QjnHKuS6riB+1pGOf7JprANVg67ui2dP8a0yTHI/UiCkDOt61pAHUftR3s6UNoI55TR37dTq6pQ1gW58C1o89+eryGu33jvb3ckut30m+H+3tNTSAjj3mCZB+1apg6XoY/dPUKWCd+tHsS6FCpH+nX8S6CFgHch3EdYpHi4qTrt3S9L2anUJLazSAmgnRXaK6FiVe1KyoGdWDtptaPA2gmgfNiKnZyv1TfqXsykLb3dRNIKU0gEnc87dTM2ua4VDzrhkgeeo6pniJT0E98cQTUUEttMSngOPTvXqNGjY1Rrn/rhSj+LW6QF/FRnnTosZCudQPDC3x5Qe6Pi7JTSAaox9Cn/70p6PZOjUBWmSg6580k5J7ClhFUs2FZq4LjVu8eHH0gybXR6fi45sJWrsBzDXUKVfNNulYkPsd1ayaGhY1Lrmz+sX845tAdLNP/HiWeD82dRNIoffUjKj2vX5gaKYyvk4wnoGPx6gx03FIN58lWQrlU+MKfZ/V
6OlUqJrPeNHlIbouL24AdVpbx75ip4CVN11+oVk85SBeSm0AdYzU5Qw6BRxfGnPjjTdG17c2NQMYH2d16ryUU8D6Luhz5jdO+cd+/WDQmRs1mqUs8U0gmhnVGQgt2kY5eW8C0cyf9o3eV/sm/4dnKdtZrq/11O9y/UylbhcNYKliOa8vJUAqwn/961+jg7Gu0dD/r2craYZJjZpO0+mOTx2I9YtVRVCzW03dBBJvhm6A0KkV3Z2lpis+5aRfnHoWmBoJXTithke/6FSQ9GtWMwOt0QDq82idOtDrgKpCrJkc/WrUzRE6sOrUpQqPmqP4ztGmGsBi1wDGp311s0DuDKNmFdU8JF0KbXdrNIDF3DX7ooZBp//V2OhArdOzKvZqaOMGR82zTl3pFJn+UVZ084GKqq5T0yyU8qSL/zUTGt8Eov+mU266xksNmRY1DlqSXAOo2SLNWGg/aFFTpVkTWWsmUItmJLXNan5UmJVdnSbTNYfNPasud99oW3STgnKuU7TKkGYz1ejptLuKcnzTlBpjNRE6DVxonHz0ndL3SN8hXSup4qsfRq09A6gGREVR+1Gn27Uenc7U9zn/R5pc1GyoCdP3VEuxawD1Gl27pRlG7VOdLtWpWy3KRrzkf09kplk/manZ0/dD26kZSi26TlB3jA4YMCCacVPTrWOPsqjrFLUPkixN5bPQ91n7TbnW7LEuC9GxR7OBaoDiBlA/5nRdqq7L000gOkWsz67r57Tkvq+aeu1j3eARPxi61AYwvglEPyC175QVvZeyqx968RMS8i00Q6nvm75beo1O7Wq/H3744Y03geg1+s6pYddxL74JpNixX2dy9F1S8xafDdKp5mLXAGob5aHjvhrb+DEwOrbHj4Ep9J3XmSf9o8ZRZ470Hdc1zsqOjqEaI1f9MNA1hrnNn842dJSllPrdUT5z/uegAXTs2VICpF9rKqS6WUHj9OXSgT0u0ppF0YFNdxfqwKKDdVOPgcl9jIWuf9PdnjpY6Muau6j5iy9QVsOgL6+aQl1UrOLbGg2grhHTjJUOeiqI8WNgVHy1bjUf+ne6I1Czk/FNLk01gMXuAtbpkvzPqc+sWZb4dHd8elHXsTU1+1lou1ujAdS2NOeuBknNk5o5HbjVMOsaIt1sEl+UrgZABV0zW/FjYHRwVwOmmwd0gNY4NdkqqLp+K952zdSpqOs1mgGLGyltV5K7gDWrp2IkOzWR2m8qEtrHuTMZauTjG2SURxXG+C5oravY3a56jfKuHw5yUH71GfR8Mp3G1QyYHHMfAxNnu9A4NcxqUvVDSHd7antUxFq7AdQ+UHGVpZpWGes6OjXzhWbp9SNO3zPdQalTmUlc4gdByyD3QdC5xTf/e6LPq/0hSzV58lTzmfvgXs0Kxjcq6UejTv1qNi33sTDF7jyNG9T8fDb1fda+VCb1OfRDRZmVR+5f27nnnnuiLOhHohpeHaPixjX/fdWsqFnVMUzHz1IbwLiR1ndMTZ++Ozp7oRln/f/KTqFF+0THLp3C1vdSzVL8+B+9vrnHwBQ79uvYrLMmOlZqv8R36SfZF3qigWpI7oOglYX4QdCFvvPyz3/0jD5D/PSF+GxCIYf4+O4om2UztJT6XTYb3cobQgPoACVADrw2HKoDmX5Vq6AUumi7DVcd7K2TPqA3rQ1UQ65TSCo2LO1HQD9M1RyU8him9vPpCm+pboTT7J1Ov5d60057/+xZ3n7qd4b/FnBrBJ8AtYZi67+H7qrVqRn9b1aWcmoAVUh1ilQzKs39Saqs7Jv28jl1uYmuR42fndhetrvU7dQsuq6d1rWHultXM2iabWvqOsRS35/Xtw8B6jcNoCupHS1Auh5Gp56bWjSjluSvGbhQGdwigXJqAFv0ARiEQEoCuiZalx7oOjidLtdNQnosk64lZcmOQEer3y3Zc5wCbonaf8d0tADpAuz4Ts9CLLoeJ/7TRg42hiKAAAIIIBBUoKPV75Zg0gC2RK2DNoAOCoYigAACCCDQbgRoADkF7AorAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/c54Azhv3jw766yzbPLkyfbRRx/ZsGHD7KabbrKRI0cmCiQBSsTEixBAAAEEECgrAep3hhvAJUuW2E477WRjxoyxk046yTbccEN7/fXXbfPNN4/+SbIQoCRKvAYBBBBAAIHyEqB+Z7gBHD9+vD399NP25JNPtjiVBKjFdAxEAAEEEEAgmAD1O8MN4Lbbbmv77befvfPOOzZlyhQbNGiQnXzyyXbiiSc2Gci6ujrTP/GiANXU1Fhtba1VV1cHCzIrRgABBBBAAIHkAjSAGW4Au3XrFiXljDPOsMMOO8ymTZtm48aNs2uvvdbGjh1bMEUTJkywiRMnfuK/0QAm/9LxSgQQQAABBEIL0ABmuAHs0qVLdLPHM88805jD0047LWoEn3322YLZZAYw9FeW9SOAAAIIIOAXoAHMcAM4ZMgQ22effWzSpEmNSbrmmmvsvPPOM90dnGQhQEmUeA0CCCCAAALlJUD9znADeNRRR9ncuXPXuQnk9NNPt6lTp64zK9hcZAlQeX2h2RoEEEAAAQSSCFC/M9wA6lTv7rvvHl3Td/jhh9tzzz0X3QBy/fXX29FHH50kP0aAEjH
xIgQQQAABBMpKgPqd4QZQSXzggQfs7LPPjp7/N3To0OiGkObuAs5PLwEqq+8zG4MAAggggEAiAep3xhvARClp5kUEyCvIeAQQQAABBNIXoH7TALpSR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQTWisUxAAAgAElEQVSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gI3Bu+iii+zss8+2cePG2WWXXZYokAQoERMvQgABBBBAoKwEqN80gFEgp02bZocffrhVV1fbmDFjaADL6mvKxiCAAAIIINC6AjSANIC2YsUK23nnne3qq6+28847z4YPH04D2LrfM94NAQQQQACBshKgAaQBtLFjx9oGG2xgl156qe299940gGX1FWVjEEAAAQQQaH0BGsCMN4C33367nX/++dEp4G7duhVtAOvq6kz/xIsCVFNTY7W1tdHpYxYEEEAAAQQQKH8BGsAMN4Bz5861kSNH2qOPPmo77LBDlNZiM4ATJkywiRMnfiLZNIDl/2VnCxFAAAEEEMidwOndu3emJ3AqGhoaGrIYifvuu88OOeQQq6qqavz4a9eutYqKCqusrIxm+nL/m17EDGAWk8JnRgABBBDoaALMAGZ4BnD58uX21ltvrZPp448/3rbeems766yzbPvtty+adwJUlIgXIIAAAgggUHYC1O8MN4CF0ljsFHD+GAJUdt9pNggBBBBAAIGiAtRvGsB1QkIDWPQ7wwsQQAABBBBo9wI0gDSArhATIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIpCjwYe2H9t7cxdEa+9f0tZ69e6a4dlaFQHkJUL9pAF2JJEAuPgYjgEAKAnUr6+yZPz5vM5542WrfXx6tsXe/XvapPbe13Q8aaXVVq2xVwyrr06m3darqlMIWsQoEwgtQv2kAXSkkQC4+BiOAQBsLrF2z1h647hGb/teZVt23l/Xp3ztaY+2iWnt7wFxb+5lVVtf7Y2sws16d17Od1x9uhw462Lp16trGW8bbIxBWgPpNA+hKIAFy8TEYAQTaWGD29H/bHT+/z/oO3MB69OreuLbZw2bbW0PmWH1FvXXu3NmqKqpsra21CjMb2nNT+/5W36UJbON9w9uHFaB+0wC6EkiAXHwMRgCBNhZ48IZH7YVHXrJNt6tpXNOynsvshdHP29pOa61hTYNVdq60yspKq6qoNLMKa7B6O2DAvvaVmkPaeOt4ewTCCVC/aQBd6SNALj4GI4BAGwvcftG9Nudfc23g5hs3rulf2820+YPmq9ez6NxvhVnU+1mFVVqlNViDbdR1I7t4x59aRYVexIJAxxOgftMAulJNgFx8DEYAgTYWeOjGv9rzD0+3TbfbpHFNU3f9uy3rvew/DWC85PV53Su729UjLrVOldwU0sa7iLcPJED9pgF0RY8AufgYjAACbSzwxotzTLOAGwxY33pW94jWNuUzj9uqHqv+vwEsMMnXuaKzTdrlKqv8z9QgCwIdToD6TQPoCjUBcvExGAEE2lhAdwE/NOmv9sKjL1mvPj2tT/9qe/zAx80669zvf5cCDWCPiu523agr2njreHsEwglQv2kAXekjQC4+BiOAQAoCq+pW29QHX7CXprxsS95baq+e9nKzs3/apG4V3eyGUVemsHWsAoEwAtRvGkBX8giQi4/BCCCQosDKFSvtzTfetos+vqToWqusym4efV3R1/ECBNqrAPWbBtCVXQLk4mMwAgikLPDq0tftvFcvLrpWnRX+zehJRV/HCxBorwLUbxpAV3YJkIuPwQggkLLA79640x5+/5FEa/0tDWAiJ17UPgWo3zSAruQSIBcfgxFAIGWB46d+29bYmkRrpQFMxMSL2qkA9ZsG0BVdAuTiYzACCKQscMzUbyReIw1gYipe2A4FqN80gK7YEiAXH4MRQCBlARrAlMFZXdkKUL9pAF3hJEAuPgYjgEDKAjSAKYOzurIVoH7TAL
rCSYBcfAxGAIGUBeIGsIcttl8N+JdVVZk1NJhd+W6lTa/fo3Frdusx2k7+1Ikpbx2rQyA9Aeo3DaArbQTIxcdgBBBIWUAN4A0DnrBOncwqcv4CiJpA/fONd0ZYvfW087b+sQ3p/f9/PzjlzWR1CLS5APWbBtAVMgLk4mMwAgikKLB27VpbPW8b69y56ZXW15ud8M6e9uuR11jnqmZemOJ2syoE2kKA+k0D6MoVAXLxMRgBBFIU+NOsB+2A3qdHa8yd/cvdBM0C/uNds1EjXktxy1gVAukLUL9pAF2pI0AuPgYjgECKAktf29J69Sq+wrVrzboMpgEsLsUr2rMA9ZsG0JVfAuTiYzACCKQosPKtLa1Ll+IrpAEsbsQr2r8A9ZsG0JViAuTiYzACCKQo8PHbWzZ7/V+8KWvWmHWtYQYwxV3DqgIIUL9pAF2xI0AuPgYjgECKAq/N2NI271d8hSs/NltvKA1gcSle0Z4FqN80gK78EiAXH4MRQCBFAT0C5qaaJ5q8AUSboptAbl56jn1j2+NS3DJWhUD6AtRvGkBX6giQi4/BCCCQksCaNWvs+Be+bV/t+oTt37/wXcBq/hYvM1s7eIoNXG9ASlvGahAII0D9pgF0JY8AufgYjAACKQk8NOcR+/3CO6O1ja56wr414D9NoP6JHwI9eZHZXXV72i2jrrfKisqUtozVIBBGgPpNA+hKHgFy8TEYAQRSEjjjufG2qOH9ddZWYfXW2epttVVag/1/w/fb0ZNS2ipWg0A4Aeo3DaArfQTIxcdgBBBISeDbU0+zD+2jRGujAUzExIvauQD1mwbQFWEC5OJjMAIIpCTwnWmn27L65UXXVmVVdvPo64q+jhcg0N4FqN80gK4MEyAXH4MRQCAlgZOnnW7LEzSAxww6yvYd/NmUtorVIBBOgPpNA+hKHwFy8TEYAQRSEvj2tNPsw/rip4An7XSVde3SNaWtYjUIhBOgftMAutJHgFx8DEYAgZQEjpv6LVtra4uujev/ihLxgg4iQP2mAXRFmQC5+BiMAAIpCKxdu9ZOeP4kq7f6ZtfWpaKL3Tjq6hS2iFUgEF6A+k0D6EohAXLxMRgBBFIQWLqq1k6ffpatsTXNrm3Dzv3slztflMIWsQoEwgtQv2kAXSkkQC4+BiOAQAoCS1cttZ/MvMA+WP1Bs2vbu99n7Oubj01hi1gFAuEFqN8ZbwAvvPBCu+eee+yVV16x7t272+67724XX3yxbbXVVonSSYASMfEiBBAIKLC6frVdM3uSTV/yYpOzgF0ru9oPt/6BDe21acAtZdUIpCdA/c54A7j//vvbEUccYbvssovpb2Wec845NnPmTHv55ZetZ8+eRZNIgIoS8QIEECgDgX988E/7wzv326K6921V/arG6wErrdK6VHaxT/cbbcdsehR/Aq4M9hWbkI4A9TvjDWB+zBYtWmT9+/e3KVOm2J577lk0hQSoKBEvQACBMhBQ0/fXBY/bi0tn2NLVtVZXXxdtVc+qnjas12b2xYEHWL+ufctgS9kEBNIRoH7TAK6TtNmzZ9sWW2xhM2bMsO233/4TKayrqzP9Ey8KUE1NjdXW1lp1dXU6qWUtCCCAQAsEdCp49oo3bNay12zZ6mXWpbKzbdlri+if6s69WvCODEGg/QrQANIANqa3vr7eDjroIFu6dKk99dRTBVM9YcIEmzhx4if+Gw1g+z0IsOUIZE2goaEhOgWs078VFRVZ+/h8XgQiARpAGsDGr8JJJ51kkydPjpq/wYMHF/yKMAPIkQMBBBBAAIH2L0ADSAMYpfiUU06x+++/35544gkbOnRo4mQToMRUvBABBBBAAIGyEaB+Z7wB1KmQU0891e699157/PHHo+v/SlkIUClavBYBBBBAAIHyEKB+Z7wBPPnkk+22226LZv9yn/3Xu3fv6LmAxRYCVEyI/44AAggggED5CVC/M94ANnUB9E033WTHHXdc0cQSoKJEvAABBBBAAIGyE6B+Z7wB9CaSAHkFGY8AAggggED6AtRvGkBX6giQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAm0qsPijj+zFhQvsjSWLraHBbPMNNrAdNxpgvbt2tVnvL7KZ7y202o8/tj7dutt6XbvYW0uW2AOvzrI3ly6xNW26Zeu+ebfKShvWt68dOGwrO2irbWxgdXWKa2dVCGRTgPpNA2hXXXWVXXLJJbZgwQLbcccd7YorrrBRo0Yl+kYQoERMvAiB1AVeeX+R3T7zJVuwYoV17dQpWn/dmjXWt0d3W69LF5u/f
IXVN9RbVUWFvbFkib334QpbvHJl6tuZu8IulZW23Yb97cw99rTRg2qCbgsrR6CjC1C/M94A3nHHHXbsscfatddea6NHj7bLLrvM7rrrLnv11Vetf//+RfNPgIoS8QIEUhf4YOVHdsVzfzfNAG7aZ32rrKiItqG+ocGeentO1Oh9ZpNNrV+PHvba4vft1fcX2dxly2xNQ33q25q/wu6dOtnIAQPtgs/tZ4OYCQy+P9iAjitA/c54A6imb5dddrErr7zyPwWivt5qamrs1FNPtfHjxxdNPgEqSsQLEEhd4G9z3rQ7Zs6wLfv2a2z+tBErV6+2p+e+bbV1K23nAQNtYK9qe2bu26aGcd7y5alvZ6EVdq6osA17rmen77q7Hbrt9mWxTWwEAh1RgPqd4QZw1apV1qNHD7v77rvt4IMPbsz32LFjbenSpXb//fd/IvN1dXWmf+JFAVLDWFtba9X8Wu+Ixwg+UzsUmPSP523GwoU2dP3119n69z780KbNf8cqzaIma9M+fey5ee/YBytX2gcfhz39G2+oTkmv37Wb7TdsS/vZZz/fDvXZZATahwANYIYbwPnz59ugQYPsmWeesd12260xsWeeeaZNmTLFpk6d+okUT5gwwSZOnPiJf08D2D6+8GxlNgRoALOxn/mUCHgEaABpAEtqAJkB9HzdGItAOgKcAk7HmbUg0J4FaAAz3AC25BRwftgJUHv++rPtHVUgvgnk/Y8+tKF9NuAmkI66o/lcCDgEqN8ZbgCVG90Eoke+6NEvWnQTyCabbGKnnHIKN4E4vlgMRSC0wDqPganqZFZh9vGaNdavHTwG5qw99rRRPAYmdIRYfwcXoAHMeAOox8Dopo/rrrsuagT1GJg777zTXnnlFdtoo42Kxp8AFSXiBQgEE9BjYF56b4HN/uCTD4J+ZfH7NmPhgv88CLp7d+vVpYvNWbLEHnxtlr25ZImtTnGru1dW2uZ9+9oXttjKvrQlD4JOkZ5VZViA+p3xBlDZ1yNg4gdBDx8+3C6//PJoZjDJQoCSKPEaBBBAAAEEykuA+k0D6EokAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPU7ow3gnDlz7Gc/+5k99thjtmDBAhs4cKB97Wtfsx/+8IfWpUuXxGEkQImpeCECCCCAAAJlI0D9zmgD+PDDD9sdd9xhRx55pA0bNsxmzpxpJ554oh1zzDH2i1/8InFACVBiKl6IAAIIIIBA2QhQvzPaABZK4CWXXGLXXHONvfnmm4kDSoASU/FCBBBAAAEEykaA+k0D2BjGc8891zQz+PzzzycOKAFKTMULEUAAAQQQKBsB6jcNYBTG2bNn24gRI6LTvzoV3NRSV1dn+ideFKCamhqrra216urqsgk2G4IAAggggAACTQvQAHawBnD8+PF28cUXN5v5WbNm2dZbb934mnnz5tlee+1le++9t02aNKnZsRMmTLCJEyd+4jU0gBxmEEAAAQQQaD8CNIAdrAFctGiRLV68uNkEbrbZZo13+s6fPz9q/HbddVe7+eabrbKystmxzAC2ny83W4oAAggggEBTAjSAHawBLCXqmvkbM2ZMdOr31ltvtaqqqlKGR68lQCWTMQABBBBAAIHgAtTvjDaAav408zdkyBC75ZZb1mn+Nt5448TBJECJqXghAggggAACZSNA/c5oA6jTvccff3zBIDY0NCQOKAFKTMULEUAAAQQQKBsB6ndGG8DWSiABai1J3gcBBBBAAIH0BKjfNICutBEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4B
cvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1Ipcb4cAAAznSURBVG8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAra6uzkaPHm0vvviiTZ8+3YYPH544jAQoMRUvRAABBBBAoGwEqN80gDZu3Dh7/fXXbfLkyTSAZfPVZEMQQAABBBBoOwEawIw3gGr6zjjjDPvDH/5g2223HQ1g233XeGcEEEAAAQTKRoAGMMMN4MKFC23EiBF23333Wb9+/Wzo0KFFG0CdLtY/8VJbW2ubbLKJzZ0716qrq8sm2GwIAggggAACCDQtoAawpqbGli5dar17984kVUVDQ0ND1j65PvKBBx5on/70p+3cc8+1OXPmJGoAJ0yYYBMnTswaF58XAQQQQACBDinwxhtv2GabbdYhP1uxD9WhGsDx48fbxRdf3OxnnjVrlj3yyCN255132pQpU6yqqipxA5g/A6hfDkOGDLG33347s78gigUsrf8e/5pjNjYt8eK/rNkX4feFtoDvRnnsB/ZF+ewHbUl8Bm/JkiXWp0+f8tq4lLamQzWAixYtssWLFzdLp07/8MMPtz/96U9WUVHR+Nq1a9dGzeDRRx9tt9xySyJ+riFIxJTKi9gXqTAnWgn7IhFTai9if6RGXXRF7IuiRKm9gH2R0WsANWOnnR8v8+fPt/3228/uvvvu6JEwgwcPThRCApSIKZUXsS9SYU60EvZFIqbUXsT+SI266IrYF0WJUnsB+yKjDWB+wpJeA5g/jgCl9l0tuiL2RVGi1F7AvkiNOtGK2B+JmFJ5EfsiFeZEK2Ff0ABGQWlpA6hrAi+88EI7++yzrWvXrolCx4vaRoB90TauLXlX9kVL1NpuDPuj7WxLfWf2Ralibfd69gUNYNuli3dGAAEEEEAAAQTKVKBD3QRSpsZsFgIIIIAAAgggUFYCNIBltTvYGAQQQAABBBBAoO0FaADb3pg1IIAAAggggAACZSVAA1hWu4ONQQABBBBAAAEE2l6ABrCFxldddZVdcskltmDBAttxxx3tiiuusFGjRrXw3RiWREB3XN9zzz32yiuvWPfu3W333XeP/vLLVltt1Tj8448/tu9973t2++23R3+3Wc93vPrqq22jjTZKsgpe00KBiy66KLobfty4cXbZZZdF78K+aCFmC4fNmzfPzjrrLJs8ebJ99NFHNmzYMLvpppts5MiR0TvqT2D+5Cc/sRtuuCH6+6f6U5jXXHONbbHFFi1cI8MKCeiPCujPht56661RfRg4cKAdd9xx0Z8djf/4APuibbLzxBNPRHX5hRdesHfffdfuvfdeO/jggxtXlsT9gw8+sFNPPTX6YxGVlZV26KGH2q9+9Stbb7312majA74rDWAL8O+44w479thj7dprr40eHK2Cd9ddd9mrr75q/fv3b8E7MiSJwP77729HHHGE7bLLLrZmzRo755xzbObMmfbyyy9bz549o7c46aST7MEHH7Sbb745+vN8p5xySvQlfvrpp5Osgte0QGDatGnRX9eprq62MWPGNDaA7IsWYLZwiP6c1U477RT5y33DDTe0119/3TbffPPoHy36saQfUfpLR0OHDrUf/ehHNmPGjOj7061btxaumWH5AhdccIH98pe/jJy32247e/755+3444+3888/30477TT2RRtGRj9+dKwfMWKEffnLX/5EA5jkO3DAAQdEzeN1111nq1evjvadas5tt93Whlse5q1pAFvgrqZPgbjyyiuj0fX19VZTUxP9atDfI2ZJR0B/+k8Nt/6m85577hn9bUcVPn1Rv/KVr0QbodnCbbbZxp599lnbdddd09mwDK1lxYoVtvPOO0ezrOedd54NHz48agDZF+mGQMcdFb4nn3yy4Io186GZKM2Of//7349eo32kmXH9WNIPK5bWEfjiF78Yud54442Nb6hZJJ210Kwg+6J1nIu9i2Zbc2cAk7jPmjXLtt12W9OP2njm/OGHH7YDDzzQ3nnnneg71JEWGsAS9+aqVausR48e0Z+Ny51aHjt2bHRa5f777y/xHXl5SwVmz54dnb7SLMb2229vjz32mH3uc5+z/D/uPWTIEPvud79rp59+ektXxbgmBJT7DTbYwC699FLbe++9GxtA9kW6kVHR0uUOKlL6QTRo0CA7+eST7cQTT4w25M0334xmAqdPnx7to3jZa6+9ov9fp7hYWkdAM4DXX3+9PfLII7blllvaiy++aPvuu280K6i/Nc++aB3nYu+S3wAmcf/1r38d/UhSDYkXnW3SDLnO8h1yyCHFVtuu/jsNYIm7S383WAfXZ555xnbbbbfG0WeeeWZ04J06dWqJ78jLWyKgWdeDDjooarqfeuqp6C0086fpel37l7vo2kydGtP0P0vrCeg6S53W0q9lHSBzG0D2Res5J3mn+BTuGWecYYcddli0T3Q9pi5TUZOu45Wu+dPxa8CAAY1vqVP3KpS6rIWldQR0bNLlKT//+c+tqqrKdE2gvie6RlYL+6J1nIu9S34DmMRdzbtO3etyrtxFZ5omTpwYXV7RkRYawBL3Jg1giWBt9HJ9EXW9h5q/wYMH0wC2kXNTbzt37tzoFMmjjz5qO+ywQ/QyGsCUd0LO6rp06RLtDxW5eNH1ZmoEdflDkuIXbus71pr1w+gHP/hBdDOCrgH85z//GZ2B0AwgzXh6+5oGsLg1DWBxo3VewSngEsHa4OW6sUOn2nXHly5mjxdOO7YBdhNved9990WnQzTDES+a6dBBVzfd/PnPf7bPf/7znI5PaZfoMod99tnHJk2a1LhG3eGr6zJ1d3CS018pbWqHX42uB9c1md/5zncaP6v
2g67/0zXJ7It0IsAp4OLONIDFjT7xCt0EotOKevSLFk35b7LJJtEdp9wE0gLQhEN0Ea9utNGFvY8//vgnHl8R33jw+9//Prp1X4um8rfeemtuAklonPRly5cvt7feemudl+v0u6z1KBIVQd2Qw75IKup73VFHHWWalc29CUTXvOqSFM3+xRfA6wYQXeOkZdmyZdFNVNwE4rPPH923b9+o8c49Xai7r/VIntdee4190brcTb5bUzeBNPcdiG8C0Z3bupNYi67l1BMouAkkpR1X7qvR9TKaytdt4moEddfjnXfeGf2643lzbbf3dFG7ri3T7F/us//0uBfdYadFB92HHnooKmp6LIkaRi25p8babguz/c65p4DZF+lmQad69VxMXaek6/qee+656AYQ3YygGw+06BpYPa8x9zEwL730Eo+BaeVdpWf+/eUvf4nqg04B68abb37zm3bCCSc0XofMvmhl9P++nZ5KoJsDteixSDrtruu/daOaJmmSuOsxMAsXLoyun40fA6PLK3gMTNvss3b5rnoETPwgaN1Fd/nll0fPBGRpO4H4Iar5a9Avax10tcQPH9bMU+6DoDfeeOO22zDeORLIbwDZF+kG44EHHohuNNDz/3RphG4Iie8C1pbED8FVU6ibp/bYY4/o8T26U5Wl9QQ0O65nLOpMxXvvvRc9OuTII4+0H//4x6ZrNdkXrWed/046M6SGL3/RhI0mBZJ8B/QgaJ3Ny30QtOo7D4Juu/3GOyOAAAIIIIAAAgikJMA1gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBP4P8ta9QDIyYDwAAAAASUVORK5CYII=\" width=\"640\">"
|
247 |
+
],
|
248 |
+
"text/plain": [
|
249 |
+
"<IPython.core.display.HTML object>"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
"metadata": {},
|
253 |
+
"output_type": "display_data"
|
254 |
+
}
|
255 |
+
],
|
256 |
+
"source": [
|
257 |
+
"anim = animate_episode(predictor_data[episode_id][\"solver_info\"], predictor_data[episode_id][\"ado_position_future\"], predictor_data[episode_id][\"ado_unbiased_predictions\"])\n",
|
258 |
+
"anim.save(os.path.join(stats_dir, f\"ep_{episode_id}_predictor.mp4\"), writer=\"ffmpeg\")"
|
259 |
+
]
|
260 |
+
},
|
261 |
+
{
|
262 |
+
"cell_type": "code",
|
263 |
+
"execution_count": null,
|
264 |
+
"id": "39c17b02",
|
265 |
+
"metadata": {},
|
266 |
+
"outputs": [],
|
267 |
+
"source": []
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"cell_type": "code",
|
271 |
+
"execution_count": null,
|
272 |
+
"id": "144aec04",
|
273 |
+
"metadata": {},
|
274 |
+
"outputs": [],
|
275 |
+
"source": []
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"cell_type": "code",
|
279 |
+
"execution_count": null,
|
280 |
+
"id": "b5dbaa75",
|
281 |
+
"metadata": {},
|
282 |
+
"outputs": [],
|
283 |
+
"source": []
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"cell_type": "code",
|
287 |
+
"execution_count": null,
|
288 |
+
"id": "a86d7330",
|
289 |
+
"metadata": {},
|
290 |
+
"outputs": [],
|
291 |
+
"source": []
|
292 |
+
},
|
293 |
+
{
|
294 |
+
"cell_type": "code",
|
295 |
+
"execution_count": null,
|
296 |
+
"id": "da0e8629",
|
297 |
+
"metadata": {},
|
298 |
+
"outputs": [],
|
299 |
+
"source": []
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "code",
|
303 |
+
"execution_count": null,
|
304 |
+
"id": "342b6e50",
|
305 |
+
"metadata": {},
|
306 |
+
"outputs": [],
|
307 |
+
"source": []
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": null,
|
312 |
+
"id": "11cf1be2",
|
313 |
+
"metadata": {},
|
314 |
+
"outputs": [],
|
315 |
+
"source": []
|
316 |
+
}
|
317 |
+
],
|
318 |
+
"metadata": {
|
319 |
+
"kernelspec": {
|
320 |
+
"display_name": "Python 3 (ipykernel)",
|
321 |
+
"language": "python",
|
322 |
+
"name": "python3"
|
323 |
+
},
|
324 |
+
"language_info": {
|
325 |
+
"codemirror_mode": {
|
326 |
+
"name": "ipython",
|
327 |
+
"version": 3
|
328 |
+
},
|
329 |
+
"file_extension": ".py",
|
330 |
+
"mimetype": "text/x-python",
|
331 |
+
"name": "python",
|
332 |
+
"nbconvert_exporter": "python",
|
333 |
+
"pygments_lexer": "ipython3",
|
334 |
+
"version": "3.7.13"
|
335 |
+
},
|
336 |
+
"vscode": {
|
337 |
+
"interpreter": {
|
338 |
+
"hash": "a1479098f155d52cde87e35cb1613d4d825087d81bb03677a9d084ad747a84cc"
|
339 |
+
}
|
340 |
+
}
|
341 |
+
},
|
342 |
+
"nbformat": 4,
|
343 |
+
"nbformat_minor": 5
|
344 |
+
}
|
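The last populated cell above saves a matplotlib animation with the ffmpeg writer. `animate_episode` and `predictor_data` are defined earlier in the notebook; the save call itself only needs an ffmpeg binary on the PATH. As a minimal, self-contained sketch of the same save pattern (the animated content here is a stand-in sine wave, not the planner episode):

import numpy as np
import matplotlib
matplotlib.use("Agg")  # render off-screen so no display is needed
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

fig, ax = plt.subplots()
(line,) = ax.plot([], [])
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(-1.1, 1.1)
xs = np.linspace(0, 2 * np.pi, 100)

def update(frame):
    # Shift the sine wave a little at every frame.
    line.set_data(xs, np.sin(xs + 0.1 * frame))
    return (line,)

anim = FuncAnimation(fig, update, frames=50)
anim.save("example.mp4", writer="ffmpeg")  # fails if ffmpeg is not installed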
requirements.txt
ADDED
@@ -0,0 +1,16 @@
+torch>=1.12
+matplotlib
+numpy
+mmcv>=1.4.7
+pytorch-lightning
+pytest
+setuptools>=59.5.0
+wandb
+plotly
+scipy
+gradio>=3.7
+datasets
+huggingface_hub
+einops
+pydantic
+fire
risk_biased/__init__.py
ADDED
File without changes
risk_biased/config/learning_config.py
ADDED
@@ -0,0 +1,156 @@
+from risk_biased.config.paths import (
+    log_path,
+)
+
+# WandB Project Name
+project = "RiskBiased"
+entity = "tri"
+
+# Scene Parameters
+dt = 0.1
+time_scene = 5.0
+sample_times = [t * dt for t in range(0, int(time_scene / dt))]
+ego_ref_speed = 14.0
+ego_length = 4.0
+ego_width = 1.75
+fast_speed = 2.0
+slow_speed = 1.0
+p_change_pace = 0.2
+proportion_fast = 0.5
+
+# Data Parameters
+file_name = "scene_data"
+datasets_sizes = {"train": 100000, "val": 10000, "test": 30000}
+datasets = list(datasets_sizes.keys())
+state_dim = 2
+dynamic_state_dim = 2
+num_steps = 5
+num_steps_future = len(sample_times) - num_steps
+ego_speed_init_low = 4.0
+ego_speed_init_high = 16.0
+ego_acceleration_mean_low = -1.5
+ego_acceleration_mean_high = 1.5
+ego_acceleration_std = 3.0
+perception_noise_std = 0.05
+map_state_dim = 0
+max_size_lane = 0
+num_blocks = 3
+interaction_type = None
+mcg_dim_expansion = 0
+mcg_num_layers = 0
+num_attention_heads = 4
+
+
+# Model Hyperparameters
+model_type = "encoder_biased"
+condition_on_ego_future = True
+latent_dim = 2
+hidden_dim = 64
+num_vq = 256
+latent_distribution = "gaussian"  # "gaussian" or "quantized"
+num_hidden_layers = 3
+sequence_encoder_type = "MLP"  # one of "MLP", "LSTM", "maskedLSTM"
+sequence_decoder_type = "MLP"  # one of "MLP", "LSTM", "maskedLSTM"
+is_mlp_residual = True
+
+# Variational Loss Hyperparameters
+kl_weight = 0.5  # For the didactic example with a Gaussian latent, kl_weight = 0.5 is a good value; with a quantized latent, kl_weight = 0.1 is a good value
+kl_threshold = 0.1
+latent_regularization = 0.1
+
+# Risk distribution should be one of the following types:
+# {"type": "uniform", "min": 0, "max": 1},
+# {"type": "normal", "mean": 0, "sigma": 1},
+# {"type": "bernoulli", "p": 0.5, "min": 0, "max": 1},
+# {"type": "beta", "alpha": 2, "beta": 5, "min": 0, "max": 1},
+# {"type": "chi2", "k": 3, "min": 0, "scale": 1},
+# {"type": "log-normal", "mu": 0, "sigma": 1, "min": 0, "scale": 1}
+# {"type": "log-uniform", "min": 0, "max": 1, "scale": 1}
+risk_distribution = {"type": "log-uniform", "min": 0, "max": 1, "scale": 3}
+
+
+# Monte Carlo risk estimator should be one of the following types:
+# {"type": "entropic", "eps": 1e-4}
+# {"type": "cvar", "eps": 1e-4}
+
+risk_estimator = {"type": "cvar", "eps": 1e-3}
+if latent_distribution == "quantized":
+    # Number of samples used to estimate the risk from the unbiased distribution
+    n_mc_samples_risk = num_vq
+    # Number of samples used to estimate the averaged cost of the biased distribution
+    n_mc_samples_biased = num_vq
+else:
+    # Number of samples used to estimate the risk from the unbiased distribution
+    n_mc_samples_risk = 512
+    # Number of samples used to estimate the averaged cost of the biased distribution
+    n_mc_samples_biased = 256
+
+
+# Risk Loss Hyperparameters
+risk_weight = 1
+risk_assymetry_factor = 200
+use_risk_constraint = True  # For encoder_biased only
+risk_constraint_update_every_n_epoch = (
+    1  # For encoder_biased only, not used if use_risk_constraint == False
+)
+risk_constraint_weight_update_factor = (
+    1.5  # For encoder_biased only, not used if use_risk_constraint == False
+)
+risk_constraint_weight_maximum = (
+    1e5  # For encoder_biased only, not used if use_risk_constraint == False
+)
+
+
+# Training Hyperparameters
+learning_rate = 1e-4
+batch_size = 512
+num_epochs_cvae = 100
+num_epochs_bias = 100
+gpus = [0]
+seed = 0  # Setting seed to an integer value seeds the pseudo-random number generators in: pytorch, numpy, python.random
+early_stopping = False
+accumulate_grad_batches = 1
+
+num_workers = 4
+log_weights_and_grads = False
+num_samples_min_fde = 16
+val_check_interval_epoch = 1
+plot_interval_epoch = 1
+histogram_interval_epoch = 1
+
+# State Cost Hyperparameters
+cost_scale = 10
+cost_reduce = (
+    "mean"  # choose in "discounted_mean", "mean", "min", "max", "now", "final"
+)
+discount_factor = 0.95  # only used if cost_reduce == "discounted_mean", discounts the cost by this factor at each time step
+distance_bandwidth = 2
+time_bandwidth = 0.5
+min_velocity_diff = 0.03
+
+
+# List all above parameters that make a difference in the dataset to distinguish datasets once generated
+dataset_parameters = {
+    "dt": dt,
+    "time_scene": time_scene,
+    "sample_times": sample_times,
+    "ego_ref_speed": ego_ref_speed,
+    "ego_speed_init_low": ego_speed_init_low,
+    "ego_speed_init_high": ego_speed_init_high,
+    "ego_acceleration_mean_low": ego_acceleration_mean_low,
+    "ego_acceleration_mean_high": ego_acceleration_mean_high,
+    "ego_acceleration_std": ego_acceleration_std,
+    "fast_speed": fast_speed,
+    "slow_speed": slow_speed,
+    "p_change_pace": p_change_pace,
+    "proportion_fast": proportion_fast,
+    "file_name": file_name,
+    "datasets_sizes": datasets_sizes,
+    "state_dim": state_dim,
+    "num_steps": num_steps,
+    "num_steps_future": num_steps_future,
+    "perception_noise_std": perception_noise_std,
+}
+
+# List files that should be saved as log
+files_to_log = []
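The `risk_distribution` dictionary above only declares a schema; the sampler that consumes it lives elsewhere in the package. As a minimal sketch of how such a config could be turned into risk-level samples, under the assumption that each "type" maps to a standard numpy sampler (`sample_risk_level` is a hypothetical helper written for illustration, not the repository's implementation; only the "uniform" and "log-uniform" cases are shown, and the exact log-uniform parameterization may differ from the repository's):

import numpy as np

def sample_risk_level(config: dict, size: int = 1) -> np.ndarray:
    """Hypothetical sampler illustrating the risk_distribution schema."""
    if config["type"] == "uniform":
        return np.random.uniform(config["min"], config["max"], size)
    if config["type"] == "log-uniform":
        # Skewed toward "min", with "scale" controlling how strong the skew is.
        u = np.random.uniform(0.0, 1.0, size)
        span = config["max"] - config["min"]
        return config["min"] + span * (np.expm1(u * config["scale"]) / np.expm1(config["scale"]))
    raise NotImplementedError(f"Unsupported type: {config['type']}")

samples = sample_risk_level({"type": "log-uniform", "min": 0, "max": 1, "scale": 3}, size=5)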
risk_biased/config/paths.py
ADDED
@@ -0,0 +1,9 @@
+# Data split path:
+base_path = "<set your path to the pre-processed data directory here>"
+data_dir = "interactive_veh_type"  # This directory name conditions the model hyperparameters, make sure to set it correctly
+sample_dataset_path = base_path + data_dir + "/sample"
+val_dataset_path = base_path + data_dir + "/validation"
+train_dataset_path = base_path + data_dir + "/training"
+test_dataset_path = base_path + data_dir + "/sample"
+
+log_path = "<set your path to any directory where you want the logs to be stored>"
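Because the dataset paths are built by plain string concatenation, `base_path` must end with a path separator. A filled-in example with hypothetical locations:

base_path = "/data/waymo_preprocessed/"  # note the trailing slash
data_dir = "interactive_veh_type"
log_path = "/home/user/logs/risk_biased"
# Resulting split path, e.g.: "/data/waymo_preprocessed/interactive_veh_type/training"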
risk_biased/config/planning_config.py
ADDED
@@ -0,0 +1,13 @@
+# Tracking Cost Parameters
+tracking_cost_scale_longitudinal = 1e-2
+tracking_cost_scale_lateral = 1e-0
+tracking_cost_reduce = "mean"
+
+# Cross Entropy Solver Parameters
+num_control_samples = 100
+num_elite = 30
+iter_max = 10
+smoothing_factor = 0.2
+mean_warm_start = True
+acceleration_std_x_m_s2 = 2.0
+acceleration_std_y_m_s2 = 0.0
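To make the role of the Cross Entropy Solver Parameters concrete, below is a generic cross-entropy-method loop written for illustration only; the repository's actual solver lives in `risk_biased/mpc_planner/solver.py` and may differ in details such as the smoothing convention (`mean_warm_start` would correspond to initializing `mean` from the previous plan). `objective` is a stand-in cost function:

import numpy as np

def cem_minimize(objective, dim, num_control_samples=100, num_elite=30,
                 iter_max=10, smoothing_factor=0.2, init_std=2.0):
    """Generic cross-entropy method: sample, keep the elite, refit, smooth."""
    mean, std = np.zeros(dim), np.full(dim, init_std)
    for _ in range(iter_max):
        samples = np.random.normal(mean, std, size=(num_control_samples, dim))
        costs = np.array([objective(s) for s in samples])
        elite = samples[np.argsort(costs)[:num_elite]]
        # Exponential smoothing keeps the distribution from collapsing too fast.
        mean = (1 - smoothing_factor) * elite.mean(axis=0) + smoothing_factor * mean
        std = (1 - smoothing_factor) * elite.std(axis=0) + smoothing_factor * std
    return mean

best = cem_minimize(lambda u: float(np.sum((u - 1.0) ** 2)), dim=4)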
risk_biased/config/waymo_config.py
ADDED
@@ -0,0 +1,104 @@
+from risk_biased.config.paths import (
+    data_dir,
+    sample_dataset_path,
+    val_dataset_path,
+    train_dataset_path,
+    test_dataset_path,
+    log_path,
+)
+
+# Data augmentation:
+normalize_angle = True
+random_rotation = False
+angle_std = 3.14 / 4
+random_translation = False
+translation_distance_std = 0.1
+p_exchange_two_first = 0.5
+
+# Data diminution:
+min_num_observation = 2
+max_size_lane = 50
+train_dataset_size_limit = None
+val_dataset_size_limit = None
+max_num_agents = 50
+max_num_objects = 50
+
+# Data characterization:
+time_scene = 9.1
+dt = 0.1
+num_steps = 11
+num_steps_future = 80
+
+# TODO: avoid conditioning on the name of the directory in the path
+if data_dir == "interactive_veh_type":
+    map_state_dim = 2 + num_steps * 8
+    state_dim = 11
+    dynamic_state_dim = 5
+elif data_dir == "interactive_full":
+    map_state_dim = 2
+    state_dim = 5
+    dynamic_state_dim = 5
+else:
+    map_state_dim = 2
+    state_dim = 2
+    dynamic_state_dim = 2
+
+# Variational Loss Hyperparameters
+kl_weight = 1.0
+kl_threshold = 0.01
+
+# Training Parameters
+learning_rate = 3e-4
+batch_size = 64
+accumulate_grad_batches = 2
+num_epochs_cvae = 0
+num_epochs_bias = 100
+gpus = [1]
+seed = 0  # Setting seed to an integer value seeds the pseudo-random number generators in: pytorch, numpy, python.random
+num_workers = 8
+
+# Model Hyperparameters
+model_type = "interaction_biased"
+condition_on_ego_future = False
+latent_dim = 16
+hidden_dim = 128
+feature_dim = 16
+num_vq = 512
+latent_distribution = "gaussian"  # "gaussian" or "quantized"
+is_mlp_residual = True
+num_hidden_layers = 3
+num_blocks = 3
+interaction_type = "Attention"  # one of "ContextGating", "Attention", "Hybrid"
+## MCG parameters
+mcg_dim_expansion = 2
+mcg_num_layers = 0
+## Attention parameters
+num_attention_heads = 4
+sequence_encoder_type = "MLP"  # one of "MLP", "LSTM", "maskedLSTM"
+sequence_decoder_type = "MLP"  # one of "MLP", "LSTM"
+
+
+# Risk Loss Hyperparameters
+cost_reduce = "discounted_mean"  # choose in "discounted_mean", "mean", "min", "max", "now", "final"
+discount_factor = 0.95  # only used if cost_reduce == "discounted_mean", discounts the cost by this factor at each time step
+min_velocity_diff = 0.1
+n_mc_samples_risk = 32
+n_mc_samples_biased = 16
+risk_weight = 1
+risk_assymetry_factor = 30
+use_risk_constraint = True  # For encoder_biased only
+risk_constraint_update_every_n_epoch = (
+    1  # For encoder_biased only, not used if use_risk_constraint == False
+)
+risk_constraint_weight_update_factor = (
+    1.5  # For encoder_biased only, not used if use_risk_constraint == False
+)
+risk_constraint_weight_maximum = (
+    1000  # For encoder_biased only, not used if use_risk_constraint == False
+)
+
+# List files that should be saved as log
+files_to_log = [
+    "./risk_biased/models/biased_cvae_model.py",
+    "./risk_biased/models/latent_distributions.py",
+]
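For reference, the `data_dir` branch above resolves to these concrete dimensions (simple arithmetic on the values in the file):

num_steps = 11
map_state_dim = 2 + num_steps * 8  # = 90 for "interactive_veh_type"
# "interactive_full": map_state_dim = 2, state_dim = 5
# fallback:           map_state_dim = 2, state_dim = 2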
risk_biased/models/__init__.py
ADDED
File without changes
risk_biased/models/biased_cvae_model.py
ADDED
@@ -0,0 +1,907 @@
+from warnings import warn
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from risk_biased.models.map_encoder import MapEncoderNN
+from risk_biased.models.mlp import MLP
+from risk_biased.models.cvae_params import CVAEParams
+from risk_biased.models.cvae_encoders import (
+    AbstractLatentDistribution,
+    CVAEEncoder,
+    BiasedEncoderNN,
+    FutureEncoderNN,
+    InferenceEncoderNN,
+)
+from risk_biased.models.cvae_decoder import (
+    CVAEAccelerationDecoder,
+    CVAEParametrizedDecoder,
+    DecoderNN,
+)
+from risk_biased.utils.cost import BaseCostTorch, get_cost
+from risk_biased.utils.loss import (
+    reconstruction_loss,
+    risk_loss_function,
+)
+from risk_biased.models.latent_distributions import (
+    GaussianLatentDistribution,
+    QuantizedDistributionCreator,
+    AbstractLatentDistribution,
+)
+from risk_biased.utils.metrics import FDE, minFDE
+from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator
+
+
+class InferenceBiasedCVAE(nn.Module):
+    """CVAE with a biased encoder module for risk-biased trajectory forecasting.
+
+    Args:
+        absolute_encoder: encoder model for the absolute positions of the agents
+        map_encoder: encoder model for map objects
+        biased_encoder: biased encoder that uses past and auxiliary input
+        inference_encoder: inference encoder that uses only past
+        decoder: CVAE decoder model
+        prior_distribution: prior distribution for the latent space.
+    """
+
+    def __init__(
+        self,
+        absolute_encoder: MLP,
+        map_encoder: MapEncoderNN,
+        biased_encoder: CVAEEncoder,
+        inference_encoder: CVAEEncoder,
+        decoder: CVAEAccelerationDecoder,
+        prior_distribution: AbstractLatentDistribution,
+    ) -> None:
+        super().__init__()
+        self.biased_encoder = biased_encoder
+        self.inference_encoder = inference_encoder
+        self.decoder = decoder
+        self.map_encoder = map_encoder
+        self.absolute_encoder = absolute_encoder
+        self.prior_distribution = prior_distribution
+
+    def cvae_parameters(self, recurse: bool = True):
+        """Define an iterator over all the parameters related to the CVAE."""
+        yield from self.absolute_encoder.parameters(recurse=recurse)
+        yield from self.map_encoder.parameters(recurse=recurse)
+        yield from self.inference_encoder.parameters(recurse=recurse)
+        yield from self.decoder.parameters(recurse=recurse)
+
+    def biased_parameters(self, recurse: bool = True):
+        """Define an iterator over only the parameters related to the biaser."""
+        yield from self.biased_encoder.biased_parameters(recurse=recurse)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask_x: torch.Tensor,
+        map: torch.Tensor,
+        mask_map: torch.Tensor,
+        offset: torch.Tensor,
+        *,
+        x_ego: Optional[torch.Tensor] = None,
+        y_ego: Optional[torch.Tensor] = None,
+        risk_level: Optional[torch.Tensor] = None,
+        n_samples: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, AbstractLatentDistribution]:
+        """Forward function that outputs a noisy reconstruction of y and the parameters of the
+        latent posterior distribution.
+
+        Args:
+            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
+            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
+            map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
+            mask_map: (batch_size, num_objects, object_sequence_length) tensor of bool mask
+            offset: (batch_size, num_agents, state_dim) offset position from ego
+            x_ego: (batch_size, 1, num_steps, state_dim) ego history
+            y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
+            risk_level (optional): (batch_size, num_agents) tensor of risk levels desired for future
+                trajectories. Defaults to None.
+            n_samples (optional): number of samples to predict (if 0, one sample with no extra
+                dimension). Defaults to 0.
+
+        Returns:
+            noisy reconstruction y of size (batch_size, num_agents, num_steps_future, state_dim), as well as
+            the weights of the samples and the latent distribution.
+            No bias is applied to the encoder without offset or risk.
+        """
+
+        encoded_map = self.map_encoder(map, mask_map)
+        mask_map = mask_map.any(-1)
+        encoded_absolute = self.absolute_encoder(offset)
+
+        if risk_level is not None:
+            biased_latent_distribution = self.biased_encoder(
+                x,
+                mask_x,
+                encoded_absolute,
+                encoded_map,
+                mask_map,
+                x_ego=x_ego,
+                y_ego=y_ego,
+                offset=offset,
+                risk_level=risk_level,
+            )
+            inference_latent_distribution = self.inference_encoder(
+                x,
+                mask_x,
+                encoded_absolute,
+                encoded_map,
+                mask_map,
+            )
+            latent_distribution = inference_latent_distribution.average(
+                biased_latent_distribution, risk_level.unsqueeze(-1)
+            )
+        else:
+            latent_distribution = self.inference_encoder(
+                x,
+                mask_x,
+                encoded_absolute,
+                encoded_map,
+                mask_map,
+            )
+        z_sample, weights = latent_distribution.sample(n_samples=n_samples)
+
+        mask_z = mask_x.any(-1)
+        y_sample = self.decoder(
+            z_sample, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map, offset
+        )
+
+        return y_sample, weights, latent_distribution
+
+    def decode(
+        self,
+        z_samples: torch.Tensor,
+        mask_z: torch.Tensor,
+        x: torch.Tensor,
+        mask_x: torch.Tensor,
+        map: torch.Tensor,
+        mask_map: torch.Tensor,
+        offset: torch.Tensor,
+    ):
+        """Returns predicted y values conditioned on z_samples and the other observations.
+
+        Args:
+            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
+            mask_z: (batch_size, num_agents) bool mask
+            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
+            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
+            map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
+            mask_map: (batch_size, num_objects, object_sequence_length) tensor, True where map features are valid, False where it is padding
+            offset: (batch_size, num_agents, state_dim) offset position from ego.
+        """
+        encoded_map = self.map_encoder(map, mask_map)
+        mask_map = mask_map.any(-1)
+        encoded_absolute = self.absolute_encoder(offset)
+
+        return self.decoder(
+            z_samples=z_samples,
+            mask_z=mask_z,
+            x=x,
+            mask_x=mask_x,
+            encoded_absolute=encoded_absolute,
+            encoded_map=encoded_map,
+            mask_map=mask_map,
+            offset=offset,
+        )
+
+
+class TrainingBiasedCVAE(InferenceBiasedCVAE):
+
+    """CVAE with a biased encoder module for risk-biased trajectory forecasting.
+    This module is a non-sampling-based version of BiasedLatentCVAE.
+
+    Args:
+        absolute_encoder: encoder model for the absolute positions of the agents
+        map_encoder: encoder model for map objects
+        biased_encoder: biased encoder that uses past and auxiliary input
+        inference_encoder: inference encoder that uses only past
+        decoder: CVAE decoder model
+        future_encoder: training encoder that uses past and future
+        cost_function: cost function used to compute the risk objective
+        risk_estimator: risk estimator used to compute the risk objective
+        prior_distribution: prior distribution for the latent space.
+        training_mode (optional): set to "cvae" to train the unbiased model, set to "bias" to train
+            the biased encoder. Defaults to "cvae".
+        latent_regularization (optional): regularization term for the latent space. Defaults to 0.
+        risk_assymetry_factor (optional): risk asymmetry factor used to compute the risk objective avoiding underestimations.
+    """
+
+    def __init__(
+        self,
+        absolute_encoder: MLP,
+        map_encoder: MapEncoderNN,
+        biased_encoder: CVAEEncoder,
+        inference_encoder: CVAEEncoder,
+        decoder: CVAEAccelerationDecoder,
+        future_encoder: CVAEEncoder,
+        cost_function: BaseCostTorch,
+        risk_estimator: AbstractMonteCarloRiskEstimator,
+        prior_distribution: AbstractLatentDistribution,
+        training_mode: str = "cvae",
+        latent_regularization: float = 0.0,
+        risk_assymetry_factor: float = 100.0,
+    ) -> None:
+        super().__init__(
+            absolute_encoder,
+            map_encoder,
+            biased_encoder,
+            inference_encoder,
+            decoder,
+            prior_distribution,
+        )
+        self.future_encoder = future_encoder
+        self._cost = cost_function
+        self._risk = risk_estimator
+        self.set_training_mode(training_mode)
+        self.regularization_factor = latent_regularization
+        self.risk_assymetry_factor = risk_assymetry_factor
+
+    def cvae_parameters(self, recurse: bool = True):
+        yield from super().cvae_parameters(recurse)
+        yield from self.future_encoder.parameters(recurse)
+
+    def get_parameters(self, recurse: bool = True):
+        """Returns a list of two parameter iterators: cvae and encoder only."""
+        return [
+            self.cvae_parameters(recurse),
+            self.biased_parameters(recurse),
+        ]
+
+    def set_training_mode(self, training_mode: str) -> None:
+        """
+        Change the training mode (the get_loss function will be different depending on the mode).
+
+        Warning: This does not freeze the decoder because the gradient must pass through it.
+        The decoder should be frozen at the optimizer level when changing mode.
+        """
+        assert training_mode in ["cvae", "bias"]
+        self.training_mode = training_mode
+        if training_mode == "cvae":
+            self.get_loss = self.get_loss_cvae
+        else:
+            self.get_loss = self.get_loss_biased
+
+    def forward_future(
+        self,
+        x: torch.Tensor,
+        mask_x: torch.Tensor,
+        map: torch.Tensor,
+        mask_map: torch.Tensor,
+        y: torch.Tensor,
+        mask_y: torch.Tensor,
+        offset: torch.Tensor,
+        return_inference: bool = False,
+    ) -> Union[
+        Tuple[torch.Tensor, AbstractLatentDistribution],
+        Tuple[torch.Tensor, AbstractLatentDistribution, AbstractLatentDistribution],
+    ]:
+        """Forward function that outputs a noisy reconstruction of y and the parameters of the
+        latent posterior distribution.
+
+        Args:
+            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
+            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
+            map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
+            mask_map: (batch_size, num_objects, object_sequence_length) tensor of bool mask
+            y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
+            mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask.
+            offset: (batch_size, num_agents, state_dim) offset position from ego.
+            return_inference: (optional) Set to True if the latent inference posterior distribution
+                should also be returned. Defaults to False.
+
+        Returns:
+            noisy reconstruction y of size (batch_size, num_agents, num_steps_future, state_dim), and the
+            distribution of the latent posterior, as well as, optionally, the distribution of the latent inference posterior.
+        """
+
+        encoded_map = self.map_encoder(map, mask_map)
+        mask_map = mask_map.any(-1)
+        encoded_absolute = self.absolute_encoder(offset)
+
+        latent_distribution = self.future_encoder(
+            x,
+            mask_x,
+            y=y,
+            mask_y=mask_y,
+            encoded_absolute=encoded_absolute,
+            encoded_map=encoded_map,
+            mask_map=mask_map,
+        )
+        z_sample, weights = latent_distribution.sample()
+        mask_z = mask_x.any(-1)
+
+        y_sample = self.decoder(
+            z_sample,
+            mask_z,
+            x,
+            mask_x,
+            encoded_absolute,
+            encoded_map,
+            mask_map,
+            offset,
+        )
+
+        if return_inference:
+            inference_distribution = self.inference_encoder(
+                x,
+                mask_x,
+                encoded_absolute,
+                encoded_map,
+                mask_map,
+            )
+
+            return (
+                y_sample,
+                latent_distribution,
+                inference_distribution,
+            )
+        else:
+            return y_sample, latent_distribution
+
+    def get_loss_cvae(
+        self,
+        x: torch.Tensor,
+        mask_x: torch.Tensor,
+        map: torch.Tensor,
+        mask_map: torch.Tensor,
+        y: torch.Tensor,
+        *,
+        mask_y: torch.Tensor,
+        mask_loss: torch.Tensor,
+        offset: torch.Tensor,
+        unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
356 |
+
kl_weight: float,
|
357 |
+
kl_threshold: float,
|
358 |
+
**kwargs,
|
359 |
+
) -> Tuple[torch.Tensor, dict]:
|
360 |
+
"""Compute and return risk-biased CVAE loss averaged over batch and sequence time steps,
|
361 |
+
along with desired loss-related metrics for logging
|
362 |
+
|
363 |
+
Args:
|
364 |
+
x: (batch_size, num_agents, num_steps, state_dim) tensor of history
|
365 |
+
mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
|
366 |
+
map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
|
367 |
+
mask_map: (batch_size, num_objects, object_sequence_length) tensor True where map features are good False where it is padding
|
368 |
+
y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
|
369 |
+
mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask.
|
370 |
+
mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool mask set to True where the loss
|
371 |
+
should be computed and to False where it shouldn't
|
372 |
+
offset : (batch_size, num_agents, state_dim) offset position from ego.
|
373 |
+
unnormalizer: function that takes in a trajectory and an offset and that outputs the
|
374 |
+
unnormalized trajectory
|
375 |
+
kl_weight: weight to apply to the KL loss (normal value is 1.0, larger values can be
|
376 |
+
used for disentanglement)
|
377 |
+
kl_threshold: minimum float value threshold applied to the KL loss
|
378 |
+
|
379 |
+
Returns:
|
380 |
+
torch.Tensor: (1,) loss tensor
|
381 |
+
dict: dict that contains loss-related metrics to be logged
|
382 |
+
"""
|
383 |
+
log_dict = dict()
|
384 |
+
|
385 |
+
if not mask_loss.any():
|
386 |
+
warn("A batch is dropped because the whole loss is masked.")
|
387 |
+
return torch.zeros(1, requires_grad=True), {}
|
388 |
+
|
389 |
+
mask_z = mask_x.any(-1)
|
390 |
+
# sum_mask_z = mask_z.float().sum().clamp_min(1)
|
391 |
+
|
392 |
+
(y_sample, latent_distribution, inference_distribution) = self.forward_future(
|
393 |
+
x,
|
394 |
+
mask_x,
|
395 |
+
map,
|
396 |
+
mask_map,
|
397 |
+
y,
|
398 |
+
mask_y,
|
399 |
+
offset,
|
400 |
+
return_inference=True,
|
401 |
+
)
|
402 |
+
|
403 |
+
# sum_mask_z *= latent_distribution.mu.shape[-1]
|
404 |
+
|
405 |
+
# log_dict["latent/abs_mean"] = (
|
406 |
+
# (latent_distribution.mu.abs() * mask_z.unsqueeze(-1).float()).sum() / sum_mask_z
|
407 |
+
# ).item()
|
408 |
+
# log_dict["latent/std"] = (
|
409 |
+
# (latent_distribution.logvar.exp() * mask_z.unsqueeze(-1).float()).sum() / sum_mask_z
|
410 |
+
# ).item()
|
411 |
+
log_dict["fde/encoded"] = FDE(
|
412 |
+
unnormalizer(y_sample, offset), unnormalizer(y, offset), mask_loss
|
413 |
+
).item()
|
414 |
+
rec_loss = reconstruction_loss(y_sample, y, mask_loss)
|
415 |
+
|
416 |
+
kl_loss = latent_distribution.kl_loss(
|
417 |
+
inference_distribution,
|
418 |
+
kl_threshold,
|
419 |
+
mask_z,
|
420 |
+
)
|
421 |
+
|
422 |
+
# self.prior_distribution.to(latent_distribution.mu.device)
|
423 |
+
|
424 |
+
kl_loss_prior = latent_distribution.kl_loss(
|
425 |
+
self.prior_distribution,
|
426 |
+
kl_threshold,
|
427 |
+
mask_z,
|
428 |
+
)
|
429 |
+
|
430 |
+
sampling_loss = latent_distribution.sampling_loss()
|
431 |
+
|
432 |
+
log_dict["loss/rec"] = rec_loss.item()
|
433 |
+
log_dict["loss/kl"] = kl_loss.item()
|
434 |
+
log_dict["loss/kl_prior"] = kl_loss_prior.item()
|
435 |
+
log_dict["loss/sampling"] = sampling_loss.item()
|
436 |
+
log_dict.update(latent_distribution.log_dict("future"))
|
437 |
+
log_dict.update(inference_distribution.log_dict("inference"))
|
438 |
+
|
439 |
+
loss = (
|
440 |
+
rec_loss
|
441 |
+
+ kl_weight * kl_loss
|
442 |
+
+ self.regularization_factor * kl_loss_prior
|
443 |
+
+ sampling_loss
|
444 |
+
)
|
445 |
+
|
446 |
+
log_dict["loss/total"] = loss.item()
|
447 |
+
|
448 |
+
return loss, log_dict
|
449 |
+
|
450 |
+
def get_loss_biased(
|
451 |
+
self,
|
452 |
+
x: torch.Tensor,
|
453 |
+
mask_x: torch.Tensor,
|
454 |
+
map: torch.Tensor,
|
455 |
+
mask_map: torch.Tensor,
|
456 |
+
y: torch.Tensor,
|
457 |
+
*,
|
458 |
+
mask_loss: torch.Tensor,
|
459 |
+
offset: torch.Tensor,
|
460 |
+
unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
461 |
+
risk_level: torch.Tensor,
|
462 |
+
x_ego: torch.Tensor,
|
463 |
+
y_ego: torch.Tensor,
|
464 |
+
kl_weight: float,
|
465 |
+
kl_threshold: float,
|
466 |
+
risk_weight: float,
|
467 |
+
n_samples_risk: int,
|
468 |
+
n_samples_biased: int,
|
469 |
+
dt: float,
|
470 |
+
**kwargs,
|
471 |
+
) -> Tuple[torch.Tensor, dict]:
|
472 |
+
"""Compute and return risk-biased CVAE loss averaged over batch and sequence time steps,
|
473 |
+
along with desired loss-related metrics for logging
|
474 |
+
|
475 |
+
Args:
|
476 |
+
x: (batch_size, num_agents, num_steps, state_dim) tensor of history
|
477 |
+
mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
|
478 |
+
map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
|
479 |
+
mask_map: (batch_size, num_objects, object_sequence_length) tensor True where map features are good False where it is padding
|
480 |
+
y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
|
481 |
+
mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool mask set to True where the loss
|
482 |
+
should be computed and to False where it shouldn't
|
483 |
+
offset : (batch_size, num_agents, state_dim) offset position from ego.
|
484 |
+
unnormalizer: function that takes in a trajectory and an offset and that outputs the
|
485 |
+
unnormalized trajectory
|
486 |
+
risk_level: (batch_size, num_agents) tensor of risk levels desired for future trajectories
|
487 |
+
x_ego: (batch_size, 1, num_steps, state_dim) tensor of ego history
|
488 |
+
y_ego: (batch_size, 1, num_steps_future, state_dim) tensor of ego future trajectory
|
489 |
+
kl_weight: weight to apply to the KL loss (normal value is 1.0, larger values can be
|
490 |
+
used for disentanglement)
|
491 |
+
kl_threshold: minimum float value threshold applied to the KL loss
|
492 |
+
risk_weight: weight to apply to the risk loss (beta parameter in our document)
|
493 |
+
n_samples_risk: number of sample to use for Monte-Carlo estimation of the risk using the unbiased distribution
|
494 |
+
n_samples_biased: number of sample to use for Monte-Carlo estimation of the risk using the biased distribution
|
495 |
+
dt: time step in trajectories
|
496 |
+
|
497 |
+
Returns:
|
498 |
+
torch.Tensor: (1,) loss tensor
|
499 |
+
dict: dict that contains loss-related metrics to be logged
|
500 |
+
"""
|
501 |
+
log_dict = dict()
|
502 |
+
|
503 |
+
if not mask_loss.any():
|
504 |
+
warn("A batch is dropped because the whole loss is masked.")
|
505 |
+
return torch.zeros(1, requires_grad=True), {}
|
506 |
+
|
507 |
+
mask_z = mask_x.any(-1)
|
508 |
+
|
509 |
+
# Computing unbiased samples
|
510 |
+
n_samples_risk = max(1, n_samples_risk)
|
511 |
+
n_samples_biased = max(1, n_samples_biased)
|
512 |
+
cost = []
|
513 |
+
weights = []
|
514 |
+
pack_size = min(n_samples_risk, n_samples_biased)
|
515 |
+
with torch.no_grad():
|
516 |
+
encoded_map = self.map_encoder(map, mask_map)
|
517 |
+
mask_map = mask_map.any(-1)
|
518 |
+
encoded_absolute = self.absolute_encoder(offset)
|
519 |
+
|
520 |
+
inference_distribution = self.inference_encoder(
|
521 |
+
x,
|
522 |
+
mask_x,
|
523 |
+
encoded_absolute,
|
524 |
+
encoded_map,
|
525 |
+
mask_map,
|
526 |
+
)
|
527 |
+
for _ in range(n_samples_risk // pack_size):
|
528 |
+
z_samples, w = inference_distribution.sample(
|
529 |
+
n_samples=pack_size,
|
530 |
+
)
|
531 |
+
|
532 |
+
y_samples = self.decoder(
|
533 |
+
z_samples=z_samples,
|
534 |
+
mask_z=mask_z,
|
535 |
+
x=x,
|
536 |
+
mask_x=mask_x,
|
537 |
+
encoded_absolute=encoded_absolute,
|
538 |
+
encoded_map=encoded_map,
|
539 |
+
mask_map=mask_map,
|
540 |
+
offset=offset,
|
541 |
+
)
|
542 |
+
|
543 |
+
mask_loss_samples = repeat(mask_loss, "b a t -> b a s t", s=pack_size)
|
544 |
+
# Computing unbiased cost
|
545 |
+
cost.append(
|
546 |
+
get_cost(
|
547 |
+
self._cost,
|
548 |
+
x,
|
549 |
+
y_samples,
|
550 |
+
offset,
|
551 |
+
x_ego,
|
552 |
+
y_ego,
|
553 |
+
dt,
|
554 |
+
unnormalizer,
|
555 |
+
mask_loss_samples,
|
556 |
+
)
|
557 |
+
)
|
558 |
+
weights.append(w)
|
559 |
+
|
560 |
+
cost = torch.cat(cost, 2)
|
561 |
+
weights = torch.cat(weights, 2)
|
562 |
+
risk_cost = self._risk(risk_level, cost, weights)
|
563 |
+
|
564 |
+
log_dict["fde/prior"] = FDE(
|
565 |
+
unnormalizer(y_samples, offset),
|
566 |
+
unnormalizer(y, offset).unsqueeze(-3),
|
567 |
+
mask_loss_samples,
|
568 |
+
).item()
|
569 |
+
|
570 |
+
mask_cost_samples = repeat(mask_z, "b a -> b a s", s=n_samples_risk)
|
571 |
+
mean_cost = (cost * mask_cost_samples.float() * weights).sum(2) / (
|
572 |
+
(mask_cost_samples.float() * weights).sum(2).clamp_min(1)
|
573 |
+
)
|
574 |
+
log_dict["cost/mean"] = (
|
575 |
+
(mean_cost * mask_loss.any(-1).float()).sum()
|
576 |
+
/ (mask_loss.any(-1).float().sum())
|
577 |
+
).item()
|
578 |
+
|
579 |
+
# Computing biased latent parameters
|
580 |
+
biased_distribution = self.biased_encoder(
|
581 |
+
x,
|
582 |
+
mask_x,
|
583 |
+
encoded_absolute.detach(),
|
584 |
+
encoded_map.detach(),
|
585 |
+
mask_map,
|
586 |
+
risk_level=risk_level,
|
587 |
+
x_ego=x_ego,
|
588 |
+
y_ego=y_ego,
|
589 |
+
offset=offset,
|
590 |
+
)
|
591 |
+
biased_distribution = inference_distribution.average(
|
592 |
+
biased_distribution, risk_level.unsqueeze(-1)
|
593 |
+
)
|
594 |
+
|
595 |
+
# sum_mask_z = mask_z.float().sum().clamp_min(1)* biased_distribution.mu.shape[-1]
|
596 |
+
# log_dict["latent/abs_mean_biased"] = (
|
597 |
+
# (biased_distribution.mu.abs() * mask_z.unsqueeze(-1).float()).sum() / sum_mask_z
|
598 |
+
# ).item()
|
599 |
+
# log_dict["latent/var_biased"] = (
|
600 |
+
# (biased_distribution.logvar.exp() * mask_z.unsqueeze(-1).float()).sum() / sum_mask_z
|
601 |
+
# ).item()
|
602 |
+
|
603 |
+
# Computing biased samples
|
604 |
+
z_biased_samples, weights = biased_distribution.sample(
|
605 |
+
n_samples=n_samples_biased
|
606 |
+
)
|
607 |
+
mask_z_samples = repeat(mask_z, "b a -> b a s ()", s=n_samples_biased)
|
608 |
+
log_dict["latent/abs_samples_biased"] = (
|
609 |
+
(z_biased_samples.abs() * mask_z_samples.float()).sum()
|
610 |
+
/ (mask_z_samples.float().sum())
|
611 |
+
).item()
|
612 |
+
|
613 |
+
y_biased_samples = self.decoder(
|
614 |
+
z_samples=z_biased_samples,
|
615 |
+
mask_z=mask_z,
|
616 |
+
x=x,
|
617 |
+
mask_x=mask_x,
|
618 |
+
encoded_absolute=encoded_absolute,
|
619 |
+
encoded_map=encoded_map,
|
620 |
+
mask_map=mask_map,
|
621 |
+
offset=offset,
|
622 |
+
)
|
623 |
+
|
624 |
+
log_dict["fde/prior_biased"] = FDE(
|
625 |
+
unnormalizer(y_biased_samples, offset),
|
626 |
+
unnormalizer(y, offset).unsqueeze(2),
|
627 |
+
mask_loss=mask_loss_samples,
|
628 |
+
).item()
|
629 |
+
|
630 |
+
# Computing biased cost
|
631 |
+
biased_cost = get_cost(
|
632 |
+
self._cost,
|
633 |
+
x,
|
634 |
+
y_biased_samples,
|
635 |
+
offset,
|
636 |
+
x_ego,
|
637 |
+
y_ego,
|
638 |
+
dt,
|
639 |
+
unnormalizer,
|
640 |
+
mask_loss_samples,
|
641 |
+
)
|
642 |
+
mask_cost_samples = mask_z_samples.squeeze(-1)
|
643 |
+
mean_biased_cost = (biased_cost * mask_cost_samples.float() * weights).sum(
|
644 |
+
2
|
645 |
+
) / ((mask_cost_samples.float() * weights).sum(2).clamp_min(1))
|
646 |
+
log_dict["cost/mean_biased"] = (
|
647 |
+
(mean_biased_cost * mask_loss.any(-1).float()).sum()
|
648 |
+
/ (mask_loss.any(-1).float().sum())
|
649 |
+
).item()
|
650 |
+
|
651 |
+
log_dict["cost/risk"] = (
|
652 |
+
(risk_cost * mask_loss.any(-1).float()).sum()
|
653 |
+
/ (mask_loss.any(-1).float().sum())
|
654 |
+
).item()
|
655 |
+
|
656 |
+
# Computing loss between risk and biased cost
|
657 |
+
risk_loss = risk_loss_function(
|
658 |
+
mean_biased_cost,
|
659 |
+
risk_cost.detach(),
|
660 |
+
mask_loss.any(-1),
|
661 |
+
self.risk_assymetry_factor,
|
662 |
+
)
|
663 |
+
log_dict["loss/risk"] = risk_loss.item()
|
664 |
+
|
665 |
+
# Computing KL loss between prior and biased latent
|
666 |
+
kl_loss = inference_distribution.kl_loss(
|
667 |
+
biased_distribution,
|
668 |
+
kl_threshold,
|
669 |
+
mask_z=mask_z,
|
670 |
+
)
|
671 |
+
log_dict["loss/kl"] = kl_loss.item()
|
672 |
+
|
673 |
+
loss = risk_weight * risk_loss + kl_weight * kl_loss
|
674 |
+
log_dict["loss/total"] = loss.item()
|
675 |
+
|
676 |
+
log_dict["loss/risk_weight"] = risk_weight
|
677 |
+
log_dict.update(inference_distribution.log_dict("inference"))
|
678 |
+
log_dict.update(biased_distribution.log_dict("biased"))
|
679 |
+
|
680 |
+
return loss, log_dict
|
681 |
+
|
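    # Summary of get_loss_biased (descriptive comment, not original code): the
    # biased encoder is trained so that the importance-weighted mean cost of
    # decoded samples from the biased latent distribution (mean_biased_cost)
    # matches the Monte-Carlo risk estimate (risk_cost) computed from the
    # unbiased distribution under torch.no_grad(), while the KL term keeps the
    # biased latent distribution close to the inference distribution.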
    def get_prediction_accuracy(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        y: torch.Tensor,
        mask_loss: torch.Tensor,
        x_ego: torch.Tensor,
        y_ego: torch.Tensor,
        offset: torch.Tensor,
        unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        risk_level: torch.Tensor,
        num_samples_min_fde: int = 0,
    ) -> dict:
        """
        A function that calls the predict method and returns a dict that contains prediction
        metrics, which measure accuracy with respect to the ground-truth future trajectory y.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor, True where map features are valid, False where padding
            y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
            mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool mask set to True where the loss
                should be computed and to False where it shouldn't
            x_ego: (batch_size, 1, num_steps, state_dim) tensor of ego history
            y_ego: (batch_size, 1, num_steps_future, state_dim) tensor of ego future trajectory
            offset: (batch_size, num_agents, state_dim) offset position from ego
            unnormalizer: function that takes in a trajectory and an offset and that outputs the
                unnormalized trajectory
            risk_level: (batch_size, num_agents) tensor of risk levels desired for future trajectories
            num_samples_min_fde: number of samples to use when computing the minimum final displacement error
        Returns:
            dict: dict that contains prediction-related metrics to be logged
        """
        log_dict = dict()
        with torch.no_grad():
            batch_size = x.shape[0]
            beg = 0
            y_predict = []

            # Limit the batch size so the num_samples_min_fde value does not impact the memory usage
            for i in range(batch_size // num_samples_min_fde + 1):
                sub_batch_size = num_samples_min_fde
                end = beg + sub_batch_size

                y_predict.append(
                    unnormalizer(
                        self.forward(
                            x=x[beg:end],
                            mask_x=mask_x[beg:end],
                            map=map[beg:end],
                            mask_map=mask_map[beg:end],
                            offset=offset[beg:end],
                            x_ego=x_ego[beg:end],
                            y_ego=y_ego[beg:end],
                            risk_level=None,
                            n_samples=num_samples_min_fde,
                        )[0],
                        offset[beg:end],
                    )
                )
                beg = end
                if beg >= batch_size:
                    break

            # Limit the batch size so the num_samples_min_fde value does not impact the memory usage
            if risk_level is not None:
                y_predict_biased = []
                beg = 0
                for i in range(batch_size // num_samples_min_fde + 1):
                    sub_batch_size = num_samples_min_fde
                    end = beg + sub_batch_size
                    y_predict_biased.append(
                        unnormalizer(
                            self.forward(
                                x=x[beg:end],
                                mask_x=mask_x[beg:end],
                                map=map[beg:end],
                                mask_map=mask_map[beg:end],
                                offset=offset[beg:end],
                                x_ego=x_ego[beg:end],
                                y_ego=y_ego[beg:end],
                                risk_level=risk_level[beg:end],
                                n_samples=num_samples_min_fde,
                            )[0],
                            offset[beg:end],
                        )
                    )
                    beg = end
                    if beg >= batch_size:
                        break
                y_predict_biased = torch.cat(y_predict_biased, 0)
                if num_samples_min_fde > 0:
                    repeated_mask_loss = repeat(
                        mask_loss, "b a t -> b a samples t", samples=num_samples_min_fde
                    )
                    log_dict["fde/prior_biased"] = FDE(
                        y_predict_biased, y.unsqueeze(-3), mask_loss=repeated_mask_loss
                    ).item()
                    log_dict["minfde/prior_biased"] = minFDE(
                        y_predict_biased, y.unsqueeze(-3), mask_loss=repeated_mask_loss
                    ).item()
                else:
                    log_dict["fde/prior_biased"] = FDE(
                        y_predict_biased, y, mask_loss=mask_loss
                    ).item()

            y_predict = torch.cat(y_predict, 0)
            y_unnormalized = unnormalizer(y, offset)
            if num_samples_min_fde > 0:
                repeated_mask_loss = repeat(
                    mask_loss, "b a t -> b a samples t", samples=num_samples_min_fde
                )
                log_dict["fde/prior"] = FDE(
                    y_predict, y_unnormalized.unsqueeze(-3), mask_loss=repeated_mask_loss
                ).item()
                log_dict["minfde/prior"] = minFDE(
                    y_predict, y_unnormalized.unsqueeze(-3), mask_loss=repeated_mask_loss
                ).item()
            else:
                log_dict["fde/prior"] = FDE(
                    y_predict, y_unnormalized, mask_loss=mask_loss
                ).item()
        return log_dict


def cvae_factory(
    params: CVAEParams,
    cost_function: BaseCostTorch,
    risk_estimator: AbstractMonteCarloRiskEstimator,
    training_mode: str = "cvae",
):
    """Biased CVAE with a biased MLP encoder and an MLP decoder

    Args:
        params: dataclass defining the necessary parameters
        cost_function: cost function used to compute the risk objective
        risk_estimator: risk estimator used to compute the risk objective
        training_mode: "inference", "cvae", or "bias"; sets the training mode
        params.latent_distribution: "gaussian" or "quantized"; sets the latent distribution
    """

    absolute_encoder_nn = MLP(
        params.dynamic_state_dim,
        params.hidden_dim,
        params.hidden_dim,
        params.num_hidden_layers,
        params.is_mlp_residual,
    )

    map_encoder_nn = MapEncoderNN(params)

    if params.latent_distribution == "gaussian":
        latent_distribution_creator = GaussianLatentDistribution
        prior_distribution = GaussianLatentDistribution(
            torch.zeros(1, 1, 2 * params.latent_dim)
        )
        future_encoder_latent_dim = 2 * params.latent_dim
        inference_encoder_latent_dim = 2 * params.latent_dim
        biased_encoder_latent_dim = 2 * params.latent_dim
    elif params.latent_distribution == "quantized":
        latent_distribution_creator = QuantizedDistributionCreator(
            params.latent_dim, params.num_vq
        )
        prior_distribution = latent_distribution_creator(
            torch.zeros(1, 1, params.num_vq)
        )
        future_encoder_latent_dim = params.latent_dim
        inference_encoder_latent_dim = params.num_vq
        biased_encoder_latent_dim = params.num_vq

    biased_encoder_nn = BiasedEncoderNN(
        params,
        biased_encoder_latent_dim,
        num_steps=params.num_steps,
    )
    biased_encoder = CVAEEncoder(
        biased_encoder_nn, latent_distribution_creator=latent_distribution_creator
    )

    future_encoder_nn = FutureEncoderNN(
        params, future_encoder_latent_dim, params.num_steps + params.num_steps_future
    )
    future_encoder = CVAEEncoder(
        future_encoder_nn, latent_distribution_creator=latent_distribution_creator
    )

    inference_encoder_nn = InferenceEncoderNN(
        params, inference_encoder_latent_dim, params.num_steps
    )
    inference_encoder = CVAEEncoder(
        inference_encoder_nn, latent_distribution_creator=latent_distribution_creator
    )

    decoder_nn = DecoderNN(params)
    decoder = CVAEAccelerationDecoder(decoder_nn)
    # decoder = CVAEParametrizedDecoder(decoder_nn)

    if training_mode == "inference":
        cvae = InferenceBiasedCVAE(
            absolute_encoder_nn,
            map_encoder_nn,
            biased_encoder,
            inference_encoder,
            decoder,
            prior_distribution=prior_distribution,
        )
        cvae.eval()
        return cvae
    else:
        return TrainingBiasedCVAE(
            absolute_encoder_nn,
            map_encoder_nn,
            biased_encoder,
            inference_encoder,
            decoder,
            future_encoder=future_encoder,
            cost_function=cost_function,
            risk_estimator=risk_estimator,
            training_mode=training_mode,
            latent_regularization=params.latent_regularization,
            risk_assymetry_factor=params.risk_assymetry_factor,
            prior_distribution=prior_distribution,
        )
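A minimal usage sketch for the two-stage training flow above (not part of the committed file; `params`, `cost`, and `risk_estimator` stand in for objects built from the repo's config, cost, and risk utilities):

    import torch

    # Stage 1: train the unbiased CVAE (reconstruction + KL objectives).
    model = cvae_factory(params, cost, risk_estimator, training_mode="cvae")
    cvae_params, biased_params = model.get_parameters()
    optimizer = torch.optim.Adam(cvae_params, lr=1e-4)

    # Stage 2: train only the biased encoder. As the set_training_mode docstring
    # warns, switching modes does not freeze the decoder by itself; freezing is
    # done here at the optimizer level by giving the optimizer only the
    # biased-encoder parameters.
    model.set_training_mode("bias")
    optimizer = torch.optim.Adam(biased_params, lr=1e-4)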
risk_biased/models/context_gating.py
ADDED
@@ -0,0 +1,53 @@
import torch
from torch import nn
from risk_biased.models.mlp import MLP


def pool(x, dim):
    x, _ = x.max(dim)
    return x


class ContextGating(nn.Module):
    """Inspired by Multi-Path++ https://arxiv.org/pdf/2111.14973v3.pdf (but not the same)

    Args:
        d_model: input dimension of the model
        d: hidden dimension of the model
        num_layers: number of layers of the MLP blocks
        is_mlp_residual: whether to use residual connections in the MLP blocks
    """

    def __init__(self, d_model, d, num_layers, is_mlp_residual):
        super().__init__()

        self.w_s = MLP(d_model, d, int((d_model + d) / 2), num_layers, is_mlp_residual)
        self.w_c_cross = MLP(
            d_model, d, int((d_model + d) / 2), num_layers, is_mlp_residual
        )
        self.w_c_global = MLP(d, d, d, num_layers, is_mlp_residual)

        self.output_layer = nn.Linear(d, d_model)

    def forward(self, s, c_cross, c_global):
        """Context gating forward function

        Args:
            s: (batch, agents, features) tensor of encoded agent states
            c_cross: (batch, objects, features) tensor of encoded object states
            c_global: (batch, d) tensor of global context

        Returns:
            s: (batch, agents, features) updated tensor of encoded agent states
            c_global: updated tensor of global context
        """
        s = self.w_s(s)
        c_cross = self.w_c_cross(c_cross)
        c_global = pool(c_cross, -2) * self.w_c_global(c_global)
        # b: batch, a: agents, k: features
        s = torch.einsum("bak,bk->bak", [s, c_global])
        s = self.output_layer(s)
        return s, c_global
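A shape sketch for ContextGating (illustrative dimensions, not from the repo's configs): agent features are projected to the hidden size, gated elementwise by a global context vector that is itself updated by max-pooling the projected cross-context features, then projected back to d_model.

    import torch
    from risk_biased.models.context_gating import ContextGating

    gate = ContextGating(d_model=64, d=32, num_layers=2, is_mlp_residual=True)
    s = torch.randn(8, 5, 64)         # (batch, agents, d_model) agent features
    c_cross = torch.randn(8, 12, 64)  # (batch, objects, d_model) cross context
    c_global = torch.randn(8, 32)     # (batch, d) global context

    s_out, c_global_out = gate(s, c_cross, c_global)
    # s_out: (8, 5, 64), c_global_out: (8, 32)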
risk_biased/models/cvae_decoder.py
ADDED
@@ -0,0 +1,388 @@
from einops import rearrange, repeat
import torch
import torch.nn as nn

from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.nn_blocks import (
    MCG,
    MAB,
    MHB,
    SequenceDecoderLSTM,
    SequenceDecoderMLP,
    SequenceEncoderLSTM,
    SequenceEncoderMLP,
    SequenceEncoderMaskedLSTM,
)


class DecoderNN(nn.Module):
    """Decoder neural network that decodes input tensors into a single output tensor.
    It contains an interaction layer that (re-)computes the interactions between the agents in the scene.
    This implies that a given latent sample for one agent will affect the predictions of the other agents too.

    Args:
        params: dataclass defining the necessary parameters

    """

    def __init__(
        self,
        params: CVAEParams,
    ) -> None:
        super().__init__()
        self.dt = params.dt
        self.state_dim = params.state_dim
        self.dynamic_state_dim = params.dynamic_state_dim
        self.hidden_dim = params.hidden_dim
        self.num_steps_future = params.num_steps_future
        self.latent_dim = params.latent_dim

        if params.sequence_encoder_type == "MLP":
            self._agent_encoder_past = SequenceEncoderMLP(
                params.state_dim,
                params.hidden_dim,
                params.num_hidden_layers,
                params.num_steps,
                params.is_mlp_residual,
            )
        elif params.sequence_encoder_type == "LSTM":
            self._agent_encoder_past = SequenceEncoderLSTM(
                params.state_dim, params.hidden_dim
            )
        elif params.sequence_encoder_type == "maskedLSTM":
            self._agent_encoder_past = SequenceEncoderMaskedLSTM(
                params.state_dim, params.hidden_dim
            )
        else:
            raise RuntimeError(
                f"Got sequence encoder type {params.sequence_encoder_type} but only knows one of: 'MLP', 'LSTM', 'maskedLSTM'"
            )

        self._combine_z_past = nn.Linear(
            params.hidden_dim + params.latent_dim, params.hidden_dim
        )

        if params.interaction_type == "Attention" or params.interaction_type == "MAB":
            self._interaction = MAB(
                params.hidden_dim, params.num_attention_heads, params.num_blocks
            )
        elif (
            params.interaction_type == "ContextGating"
            or params.interaction_type == "MCG"
        ):
            self._interaction = MCG(
                params.hidden_dim,
                params.mcg_dim_expansion,
                params.mcg_num_layers,
                params.num_blocks,
                params.is_mlp_residual,
            )
        elif params.interaction_type == "Hybrid" or params.interaction_type == "MHB":
            self._interaction = MHB(
                params.hidden_dim,
                params.num_attention_heads,
                params.mcg_dim_expansion,
                params.mcg_num_layers,
                params.num_blocks,
                params.is_mlp_residual,
            )
        else:
            self._interaction = lambda x, *args, **kwargs: x

        if params.sequence_decoder_type == "MLP":
            self._decoder = SequenceDecoderMLP(
                params.hidden_dim,
                params.num_hidden_layers,
                params.num_steps_future,
                params.is_mlp_residual,
            )
        elif params.sequence_decoder_type == "LSTM":
            self._decoder = SequenceDecoderLSTM(params.hidden_dim)
        elif params.sequence_decoder_type == "maskedLSTM":
            self._decoder = SequenceDecoderLSTM(params.hidden_dim)
        else:
            raise RuntimeError(
                f"Got sequence decoder type {params.sequence_decoder_type} but only knows one of: 'MLP', 'LSTM', 'maskedLSTM'"
            )

    def forward(
        self,
        z_samples: torch.Tensor,
        mask_z: torch.Tensor,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        encoded_absolute: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function that decodes input tensors into an output tensor of size
        (batch_size, num_agents, (n_samples), num_steps_future, state_dim)

        Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
            mask_z: (batch_size, num_agents) tensor of bool mask
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask

        Returns:
            (batch_size, num_agents, (n_samples), num_steps_future, state_dim) output tensor
        """

        encoded_x = self._agent_encoder_past(x, mask_x)
        squeeze_output_sample_dim = False
        if z_samples.ndim == 3:
            batch_size, num_agents, latent_dim = z_samples.shape
            num_samples = 1
            z_samples = rearrange(z_samples, "b a l -> b a () l")
            squeeze_output_sample_dim = True
        else:
            batch_size, num_agents, num_samples, latent_dim = z_samples.shape
        mask_z = repeat(mask_z, "b a -> (b s) a", s=num_samples)
        mask_map = repeat(mask_map, "b o -> (b s) o", s=num_samples)
        encoded_x = repeat(encoded_x, "b a l -> (b s) a l", s=num_samples)
        encoded_absolute = repeat(
            encoded_absolute, "b a l -> (b s) a l", s=num_samples
        )
        encoded_map = repeat(encoded_map, "b o l -> (b s) o l", s=num_samples)

        z_samples = rearrange(z_samples, "b a s l -> (b s) a l")

        h = self._combine_z_past(torch.cat([z_samples, encoded_x], dim=-1))

        h = self._interaction(h, mask_z, encoded_absolute, encoded_map, mask_map)

        h = self._decoder(h, self.num_steps_future)

        if not squeeze_output_sample_dim:
            h = rearrange(h, "(b s) a t l -> b a s t l", b=batch_size, s=num_samples)

        return h


class CVAEAccelerationDecoder(nn.Module):
    """Decoder architecture for conditional variational autoencoder

    Args:
        model: decoder neural network that transforms input tensors to an output sequence
    """

    def __init__(
        self,
        model: nn.Module,
    ) -> None:
        super().__init__()
        self._model = model
        self._output_layer = nn.Linear(model.hidden_dim, 2)

    def forward(
        self,
        z_samples: torch.Tensor,
        mask_z: torch.Tensor,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        encoded_absolute: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        offset: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function that decodes input tensors into an output tensor of size
        (batch_size, num_agents, (n_samples), num_steps_future, state_dim=5)
        It first predicts accelerations that are doubly integrated to produce the output
        state sequence with positions, angles, and velocities (x, y, theta, vx, vy) or (x, y, vx, vy) or (x, y)

        Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
            mask_z: (batch_size, num_agents) tensor of bool mask
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask
            offset: (batch_size, num_agents, state_dim) offset position from ego

        Returns:
            (batch_size, num_agents, (n_samples), num_steps_future, state_dim) output tensor. Sample dimension
            does not exist if z_samples is a 3D tensor.
        """

        h = self._model(
            z_samples, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map
        )
        h = self._output_layer(h)

        dt = self._model.dt
        initial_position = x[..., -1:, :2].clone()
        # If shape is 5 it should be (x, y, angle, vx, vy)
        if offset.shape[-1] == 5:
            initial_velocity = offset[..., 3:5].clone().unsqueeze(-2)
        # else if shape is 4 it should be (x, y, vx, vy)
        elif offset.shape[-1] == 4:
            initial_velocity = offset[..., 2:4].clone().unsqueeze(-2)
        elif x.shape[-1] == 5:
            initial_velocity = x[..., -1:, 3:5].clone()
        elif x.shape[-1] == 4:
            initial_velocity = x[..., -1:, 2:4].clone()
        else:
            initial_velocity = (x[..., -1:, :] - x[..., -2:-1, :]) / dt

        output = torch.zeros(
            (*h.shape[:-1], self._model.dynamic_state_dim), device=h.device
        )
        # There might be a sample dimension in the output tensor, then adapt the shape of initial position and velocity
        if output.ndim == 5:
            initial_position = initial_position.unsqueeze(-3)
            initial_velocity = initial_velocity.unsqueeze(-3)

        if self._model.dynamic_state_dim == 5:
            output[..., 3:5] = h.cumsum(-2) * dt
            output[..., :2] = (output[..., 3:5].clone() + initial_velocity).cumsum(
                -2
            ) * dt + initial_position
            output[..., 2] = torch.atan2(output[..., 4].clone(), output[..., 3].clone())
        elif self._model.dynamic_state_dim == 4:
            output[..., 2:4] = h.cumsum(-2) * dt
            output[..., :2] = (output[..., 2:4].clone() + initial_velocity).cumsum(
                -2
            ) * dt + initial_position
        else:
            velocity = h.cumsum(-2) * dt
            output = (velocity.clone() + initial_velocity).cumsum(
                -2
            ) * dt + initial_position
        return output


class CVAEParametrizedDecoder(nn.Module):
    """Decoder architecture for conditional variational autoencoder

    Args:
        model: decoder neural network that transforms input tensors to an output sequence
    """

    def __init__(
        self,
        model: nn.Module,
    ) -> None:
        super().__init__()
        self._model = model
        self._order = 3
        self._output_layer = nn.Linear(
            model.hidden_dim * model.num_steps_future,
            2 * self._order + model.num_steps_future,
        )

    def polynomial(self, x: torch.Tensor, params: torch.Tensor):
        """Polynomial function that takes a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future) and
        a parameter tensor of shape (batch_size, num_agents, (n_samples), self._order*2) and returns a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future, 2)
        """
        h = x.clone()
        squeeze = False
        if h.ndim == 3:
            h = h.unsqueeze(2)
            params = params.unsqueeze(2)
            squeeze = True
        h = repeat(
            h,
            "batch agents samples sequence -> batch agents samples sequence two order",
            order=self._order,
            two=2,
        ).cumprod(-1)
        h = h * params.view(*params.shape[:-1], 1, 2, self._order)
        h = h.sum(-1)
        if squeeze:
            h = h.squeeze(2)
        return h

    def dpolynomial(self, x: torch.Tensor, params: torch.Tensor):
        """Derivative of the polynomial function that takes a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future) and
        a parameter tensor of shape (batch_size, num_agents, (n_samples), self._order*2) and returns a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future, 2)
        """
        h = x.clone()
        squeeze = False
        if h.ndim == 3:
            h = h.unsqueeze(2)
            params = params.unsqueeze(2)
            squeeze = True
        h = repeat(
            h,
            "batch agents samples sequence -> batch agents samples sequence two order",
            order=self._order - 1,
            two=2,
        )
        h = torch.cat((torch.ones_like(h[..., :1]), h.cumprod(-1)), -1)
        h = h * params.view(*params.shape[:-1], 1, 2, self._order)
        h = h * torch.arange(self._order).view(*([1] * params.ndim), -1).to(x.device)
        h = h.sum(-1)
        if squeeze:
            h = h.squeeze(2)
        return h

    def forward(
        self,
        z_samples: torch.Tensor,
        mask_z: torch.Tensor,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        encoded_absolute: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        offset: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function that decodes input tensors into an output tensor of size
        (batch_size, num_agents, (n_samples), num_steps_future, state_dim=5)
        It predicts polynomial path coefficients and per-step arc-length increments that are
        evaluated to produce the output state sequence with positions, angles, and velocities
        (x, y, theta, vx, vy) or (x, y, vx, vy) or (x, y)

        Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
            mask_z: (batch_size, num_agents) tensor of bool mask
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask
            offset: (batch_size, num_agents, state_dim) offset position from ego

        Returns:
            (batch_size, num_agents, (n_samples), num_steps_future, state_dim) output tensor. Sample dimension
            does not exist if z_samples is a 3D tensor.
        """

        squeeze_output_sample_dim = z_samples.ndim == 3
        batch_size = z_samples.shape[0]

        h = self._model(
            z_samples, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map
        )
        if squeeze_output_sample_dim:
            h = rearrange(
                h, "batch agents sequence features -> batch agents (sequence features)"
            )
        else:
            h = rearrange(
                h,
                "(batch samples) agents sequence features -> batch agents samples (sequence features)",
                batch=batch_size,
            )
        h = self._output_layer(h)

        output = torch.zeros(
            (
                *h.shape[:-1],
                self._model.num_steps_future,
                self._model.dynamic_state_dim,
            ),
            device=h.device,
        )
        params = h[..., : 2 * self._order]
        dldt = torch.relu(h[..., 2 * self._order :])
        distance = dldt.cumsum(-2)
        output[..., :2] = self.polynomial(distance, params)
        if self._model.dynamic_state_dim == 5:
            output[..., 3:5] = dldt * self.dpolynomial(distance, params)
            output[..., 2] = torch.atan2(output[..., 4].clone(), output[..., 3].clone())
        elif self._model.dynamic_state_dim == 4:
            output[..., 2:4] = dldt * self.dpolynomial(distance, params)

        return output
risk_biased/models/cvae_encoders.py
ADDED
@@ -0,0 +1,376 @@
from typing import Optional

from einops import rearrange
import torch
import torch.nn as nn

from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.nn_blocks import (
    MCG,
    MAB,
    MHB,
    SequenceEncoderLSTM,
    SequenceEncoderMLP,
    SequenceEncoderMaskedLSTM,
)
from risk_biased.models.latent_distributions import AbstractLatentDistribution


class BaseEncoderNN(nn.Module):
    """Base encoder neural network that defines the common functionality of encoders.
    It should not be used directly but rather extended to define specific encoders.

    Args:
        params: dataclass defining the necessary parameters
        num_steps: length of the input sequence
    """

    def __init__(
        self,
        params: CVAEParams,
        latent_dim: int,
        num_steps: int,
    ) -> None:
        super().__init__()
        self.is_mlp_residual = params.is_mlp_residual
        self.num_hidden_layers = params.num_hidden_layers
        self.num_steps = params.num_steps
        self.num_steps_future = params.num_steps_future
        self.sequence_encoder_type = params.sequence_encoder_type
        self.state_dim = params.state_dim
        self.latent_dim = latent_dim
        self.hidden_dim = params.hidden_dim

        if params.sequence_encoder_type == "MLP":
            self._agent_encoder = SequenceEncoderMLP(
                params.state_dim,
                params.hidden_dim,
                params.num_hidden_layers,
                num_steps,
                params.is_mlp_residual,
            )
        elif params.sequence_encoder_type == "LSTM":
            self._agent_encoder = SequenceEncoderLSTM(
                params.state_dim, params.hidden_dim
            )
        elif params.sequence_encoder_type == "maskedLSTM":
            self._agent_encoder = SequenceEncoderMaskedLSTM(
                params.state_dim, params.hidden_dim
            )

        if params.interaction_type == "Attention" or params.interaction_type == "MAB":
            self._interaction = MAB(
                params.hidden_dim, params.num_attention_heads, params.num_blocks
            )
        elif (
            params.interaction_type == "ContextGating"
            or params.interaction_type == "MCG"
        ):
            self._interaction = MCG(
                params.hidden_dim,
                params.mcg_dim_expansion,
                params.mcg_num_layers,
                params.num_blocks,
                params.is_mlp_residual,
            )
        elif params.interaction_type == "Hybrid" or params.interaction_type == "MHB":
            self._interaction = MHB(
                params.hidden_dim,
                params.num_attention_heads,
                params.mcg_dim_expansion,
                params.mcg_num_layers,
                params.num_blocks,
                params.is_mlp_residual,
            )
        else:
            self._interaction = lambda x, *args, **kwargs: x
        self._output_layer = nn.Linear(params.hidden_dim, self.latent_dim)

    def encode_agents(self, x: torch.Tensor, mask_x: torch.Tensor, *args, **kwargs):
        raise NotImplementedError

    def forward(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        encoded_absolute: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        mask_y: Optional[torch.Tensor] = None,
        x_ego: Optional[torch.Tensor] = None,
        y_ego: Optional[torch.Tensor] = None,
        offset: Optional[torch.Tensor] = None,
        risk_level: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function that encodes input tensors into an output tensor of dimension
        latent_dim.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask
            y (optional): (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
            mask_y (optional): (batch_size, num_agents, num_steps_future) tensor of bool mask. Defaults to None.
            x_ego: (batch_size, 1, num_steps, state_dim) ego history
            y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
            offset (optional): (batch_size, num_agents, state_dim) offset position from ego.
            risk_level (optional): (batch_size, num_agents) tensor of risk levels desired for future
                trajectories. Defaults to None.

        Returns:
            (batch_size, num_agents, latent_dim) output tensor
        """
        h_agents = self.encode_agents(
            x=x,
            mask_x=mask_x,
            y=y,
            mask_y=mask_y,
            x_ego=x_ego,
            y_ego=y_ego,
            offset=offset,
            risk_level=risk_level,
        )
        mask_agent = mask_x.any(-1)
        h_agents = self._interaction(
            h_agents, mask_agent, encoded_absolute, encoded_map, mask_map
        )

        return self._output_layer(h_agents)


class BiasedEncoderNN(BaseEncoderNN):
    """Biased encoder neural network that encodes past info and auxiliary input
    into a biased distribution over the latent space.

    Args:
        params: dataclass defining the necessary parameters
        num_steps: length of the input sequence
    """

    def __init__(
        self,
        params: CVAEParams,
        latent_dim: int,
        num_steps: int,
    ) -> None:
        super().__init__(params, latent_dim, num_steps)
        self.condition_on_ego_future = params.condition_on_ego_future
        if params.sequence_encoder_type == "MLP":
            self._ego_encoder = SequenceEncoderMLP(
                params.state_dim,
                params.hidden_dim,
                params.num_hidden_layers,
                params.num_steps
                + params.num_steps_future * self.condition_on_ego_future,
                params.is_mlp_residual,
            )
        elif params.sequence_encoder_type == "LSTM":
            self._ego_encoder = SequenceEncoderLSTM(params.state_dim, params.hidden_dim)
        elif params.sequence_encoder_type == "maskedLSTM":
            self._ego_encoder = SequenceEncoderMaskedLSTM(
                params.state_dim, params.hidden_dim
            )

        self._auxiliary_encode = nn.Linear(
            params.hidden_dim + 1 + params.hidden_dim, params.hidden_dim
        )

    def biased_parameters(self, recurse: bool = True):
        """Get the parameters to be optimized when training to bias."""
        yield from self.parameters(recurse)

    def encode_agents(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        *,
        x_ego: torch.Tensor,
        y_ego: torch.Tensor,
        offset: torch.Tensor,
        risk_level: torch.Tensor,
        **kwargs,
    ):
        """Encode agent input and auxiliary input into a feature vector.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            x_ego: (batch_size, 1, num_steps, state_dim) ego history
            y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
            offset: (batch_size, num_agents, state_dim) offset position from ego.
            risk_level: (batch_size, num_agents) tensor of risk levels desired for future
                trajectories.
        Returns:
            (batch_size, latent_dim) output tensor
        """

        if self.condition_on_ego_future:
            ego_tensor = torch.cat([x_ego, y_ego], dim=-2)
        else:
            ego_tensor = x_ego

        risk_feature = ((risk_level - 0.5) * 10).exp().unsqueeze(-1)
        mask_ego = torch.ones(
            ego_tensor.shape[0],
            offset.shape[1],
            ego_tensor.shape[2],
            device=ego_tensor.device,
        )
        batch_size, n_agents, dynamic_state_dim = offset.shape
        state_dim = ego_tensor.shape[-1]
        extended_offset = torch.cat(
            (
                offset,
                torch.zeros(
                    batch_size,
                    n_agents,
                    state_dim - dynamic_state_dim,
                    device=offset.device,
                ),
            ),
            dim=-1,
        ).unsqueeze(-2)
        if extended_offset.shape[1] > 1:
            ego_encoded = self._ego_encoder(
                ego_tensor + extended_offset[:, :1] - extended_offset, mask_ego
            )
        else:
            ego_encoded = self._ego_encoder(ego_tensor - extended_offset, mask_ego)
        auxiliary_input = torch.cat((risk_feature, ego_encoded), -1)

        h_agents = self._agent_encoder(x, mask_x)
        h_agents = torch.cat([h_agents, auxiliary_input], dim=-1)
        h_agents = self._auxiliary_encode(h_agents)

        return h_agents


class InferenceEncoderNN(BaseEncoderNN):
    """Inference encoder neural network that encodes past info into the
    inference distribution over the latent space.

    Args:
        params: dataclass defining the necessary parameters
        num_steps: length of the input sequence
    """

    def biased_parameters(self, recurse: bool = True):
        yield from []

    def encode_agents(self, x: torch.Tensor, mask_x: torch.Tensor, *args, **kwargs):
        h_agents = self._agent_encoder(x, mask_x)
        return h_agents


class FutureEncoderNN(BaseEncoderNN):
    """Future encoder neural network that encodes past and future info into the
    future-conditioned distribution over the latent space.
    The future is not available at test time, so this is only used for training.

    Args:
        params: dataclass defining the necessary parameters
        num_steps: length of the input sequence

    """

    def biased_parameters(self, recurse: bool = True):
        """The future encoder is not optimized when training to bias."""
        yield from []

    def encode_agents(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        *,
        y: torch.Tensor,
        mask_y: torch.Tensor,
        **kwargs,
    ):
        """Encode agent input and future input into a feature vector.
        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of trajectory history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            y: (batch_size, num_agents, num_steps_future, state_dim) future trajectory
            mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask
        """
        mask_traj = torch.cat([mask_x, mask_y], dim=-1)
        h_agents = self._agent_encoder(torch.cat([x, y], dim=-2), mask_traj)
        return h_agents


class CVAEEncoder(nn.Module):
    """Encoder architecture for conditional variational autoencoder

    Args:
        model: encoder neural network that transforms input tensors to an unsplit latent output
        latent_distribution_creator: class that creates a latent distribution object for the latent space.
    """

    def __init__(
        self,
        model: BaseEncoderNN,
        latent_distribution_creator,
    ) -> None:
        super().__init__()
        self._model = model
        self.latent_dim = model.latent_dim
        self._latent_distribution_creator = latent_distribution_creator

    def biased_parameters(self, recurse: bool = True):
        yield from self._model.biased_parameters(recurse)

    def forward(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        encoded_absolute: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        mask_y: Optional[torch.Tensor] = None,
        x_ego: Optional[torch.Tensor] = None,
        y_ego: Optional[torch.Tensor] = None,
        offset: Optional[torch.Tensor] = None,
        risk_level: Optional[torch.Tensor] = None,
    ) -> AbstractLatentDistribution:
        """Forward function that encodes input tensors into an output tensor of dimension
        latent_dim.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask
            y (optional): (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
            mask_y (optional): (batch_size, num_agents, num_steps_future) tensor of bool mask. Defaults to None.
            x_ego (optional): (batch_size, 1, num_steps, state_dim) ego history
            y_ego (optional): (batch_size, 1, num_steps_future, state_dim) ego future
y_ego (optional): (batch_size, 1, num_steps_future, state_dim) ego future
|
352 |
+
offset (optional): (batch_size, num_agents, state_dim) offset position from ego.
|
353 |
+
risk_level (optional): (batch_size, num_agents) tensor of risk levels desired for future
|
354 |
+
trajectories. Defaults to None.
|
355 |
+
|
356 |
+
Returns:
|
357 |
+
Latent distribution representing the posterior over the latent variables given the input observations.
|
358 |
+
"""
|
359 |
+
|
360 |
+
latent_output = self._model(
|
361 |
+
x=x,
|
362 |
+
mask_x=mask_x,
|
363 |
+
encoded_absolute=encoded_absolute,
|
364 |
+
encoded_map=encoded_map,
|
365 |
+
mask_map=mask_map,
|
366 |
+
y=y,
|
367 |
+
mask_y=mask_y,
|
368 |
+
x_ego=x_ego,
|
369 |
+
y_ego=y_ego,
|
370 |
+
offset=offset,
|
371 |
+
risk_level=risk_level,
|
372 |
+
)
|
373 |
+
|
374 |
+
latent_distribution = self._latent_distribution_creator(latent_output)
|
375 |
+
|
376 |
+
return latent_distribution
|
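The risk conditioning in `encode_agents` above enters through a single scalar feature: a risk level r in [0, 1] is mapped to exp(10 (r - 0.5)), a strictly positive value that spans roughly 0.007 at r = 0 to about 148 at r = 1, so the biased encoder sees an exponentially stretched version of the requested risk. A minimal sketch of that transform in isolation (the tensor values below are made up for illustration, not taken from the repository):

import torch

# Hypothetical batch of 2 scenes with 3 agents each; values are made up.
risk_level = torch.tensor([[0.0, 0.5, 1.0], [0.2, 0.8, 0.5]])

# Same transform as in encode_agents: exp(10 * (r - 0.5)), unsqueezed to a feature dim.
risk_feature = ((risk_level - 0.5) * 10).exp().unsqueeze(-1)

print(risk_feature.shape)      # torch.Size([2, 3, 1])
print(risk_feature[0, :, 0])   # approx [0.0067, 1.0, 148.41]

Concatenating this one scalar with the encoded ego trajectory is what lets a single encoder serve every risk level instead of training one network per level.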
risk_biased/models/cvae_params.py
ADDED
@@ -0,0 +1,78 @@
from dataclasses import dataclass

from mmcv import Config


@dataclass
class CVAEParams:
    """
    state_dim: Dimension of the state at each time step.
    map_state_dim: Dimension of the map point features at each position.
    num_steps: Number of time steps in the past trajectory input.
    num_steps_future: Number of time steps in the future trajectory output.
    latent_dim: Dimension of the latent space
    hidden_dim: Dimension of the hidden layers
    num_hidden_layers: Number of layers for each model, (encoder, decoder)
    is_mlp_residual: Set to True to add a linear transformation of the input to the output of the MLP
    interaction_type: Whether to use MCG, MAB, or MHB to handle interactions
    num_attention_heads: Number of attention heads to use in MHA blocks
    mcg_dim_expansion: Dimension expansion factor for the MCG global interaction space
    mcg_num_layers: Number of layers for the MLP MCG blocks
    num_blocks: Number of interaction blocks to use
    sequence_encoder_type: Type of sequence encoder: maskedLSTM, LSTM, or MLP
    sequence_decoder_type: Type of sequence decoder: maskedLSTM, LSTM, or MLP
    condition_on_ego_future: Whether to condition the biasing on the ego future or only the ego past
    latent_regularization: Weight of the latent regularization loss
    """

    dt: float
    state_dim: int
    dynamic_state_dim: int
    map_state_dim: int
    max_size_lane: int
    num_steps: int
    num_steps_future: int
    latent_dim: int
    hidden_dim: int
    num_hidden_layers: int
    is_mlp_residual: bool
    interaction_type: int
    num_attention_heads: int
    mcg_dim_expansion: int
    mcg_num_layers: int
    num_blocks: int
    sequence_encoder_type: str
    sequence_decoder_type: str
    condition_on_ego_future: bool
    latent_regularization: float
    risk_assymetry_factor: float
    num_vq: int
    latent_distribution: str

    @staticmethod
    def from_config(cfg: Config):
        return CVAEParams(
            dt=cfg.dt,
            state_dim=cfg.state_dim,
            dynamic_state_dim=cfg.dynamic_state_dim,
            map_state_dim=cfg.map_state_dim,
            max_size_lane=cfg.max_size_lane,
            num_steps=cfg.num_steps,
            num_steps_future=cfg.num_steps_future,
            latent_dim=cfg.latent_dim,
            hidden_dim=cfg.hidden_dim,
            num_hidden_layers=cfg.num_hidden_layers,
            is_mlp_residual=cfg.is_mlp_residual,
            interaction_type=cfg.interaction_type,
            mcg_dim_expansion=cfg.mcg_dim_expansion,
            mcg_num_layers=cfg.mcg_num_layers,
            num_blocks=cfg.num_blocks,
            num_attention_heads=cfg.num_attention_heads,
            sequence_encoder_type=cfg.sequence_encoder_type,
            sequence_decoder_type=cfg.sequence_decoder_type,
            condition_on_ego_future=cfg.condition_on_ego_future,
            latent_regularization=cfg.latent_regularization,
            risk_assymetry_factor=cfg.risk_assymetry_factor,
            num_vq=cfg.num_vq,
            latent_distribution=cfg.latent_distribution,
        )
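`from_config` simply mirrors the fields of an mmcv `Config` object into the dataclass. A minimal sketch of how it could be instantiated; every value below is a placeholder chosen for illustration, not a default of this repository (in practice the values come from the repository's config files):

from mmcv import Config

from risk_biased.models.cvae_params import CVAEParams

# All values are illustrative placeholders.
cfg = Config(
    dict(
        dt=0.1,
        state_dim=5,
        dynamic_state_dim=2,
        map_state_dim=2,
        max_size_lane=20,
        num_steps=10,
        num_steps_future=30,
        latent_dim=16,
        hidden_dim=64,
        num_hidden_layers=2,
        is_mlp_residual=True,
        interaction_type="MCG",
        num_attention_heads=4,
        mcg_dim_expansion=2,
        mcg_num_layers=2,
        num_blocks=3,
        sequence_encoder_type="maskedLSTM",
        sequence_decoder_type="maskedLSTM",
        condition_on_ego_future=False,
        latent_regularization=0.1,
        risk_assymetry_factor=100.0,
        num_vq=128,
        latent_distribution="gaussian",
    )
)
params = CVAEParams.from_config(cfg)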
risk_biased/models/latent_distributions.py
ADDED
@@ -0,0 +1,468 @@
from typing import Optional, Callable, Tuple
import warnings

from abc import ABC, abstractmethod
from einops import rearrange, repeat
import torch
import torch.nn as nn


def relaxed_one_hot_categorical_without_replacement(temperature, logits, num_samples=1):
    # See the paper Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement (https://arxiv.org/pdf/1903.06059.pdf)
    # for an explanation of the trick
    scores = (
        (torch.distributions.Gumbel(logits, 1).rsample() / temperature)
        .softmax(-1)
        .clamp_min(1e-10)
    )
    _, top_indices = torch.topk(
        scores,
        num_samples,
        dim=-1,
    )
    return scores, top_indices


class AbstractLatentDistribution(nn.Module, ABC):
    """Base class for latent distribution"""

    @abstractmethod
    def sample(
        self, num_samples: int, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample from the latent distribution."""

    @abstractmethod
    def kl_loss(
        self,
        other: "GaussianLatentDistribution",
        threshold: float = 0,
        mask_z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the KL divergence between two latent distributions."""

    @abstractmethod
    def sampling_loss(self) -> torch.Tensor:
        """Loss of the latent distribution."""

    @abstractmethod
    def average(
        self, other: "AbstractLatentDistribution", weight_other: torch.Tensor
    ) -> "AbstractLatentDistribution":
        """Average of the latent distribution."""

    @abstractmethod
    def log_dict(self, type: str) -> dict:
        """Log the latent distribution values."""


class GaussianLatentDistribution(AbstractLatentDistribution):
    """Gaussian latent distribution"""

    def __init__(self, latent_representation: torch.Tensor):
        super().__init__()
        mu, logvar = torch.chunk(latent_representation, 2, dim=-1)
        self.register_buffer("mu", mu, False)
        self.register_buffer("logvar", logvar, False)

    def sample(
        self, n_samples: int = 0, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample from the Gaussian with the reparametrization trick

        Args:
            n_samples (optional): number of samples to make (if 0, one sample with no extra
                dimension). Defaults to 0.
        Returns:
            Random Gaussian sample of size (some_shape, (n_samples), latent_dim)
        """

        std = (self.logvar / 2).exp()
        if n_samples <= 0:
            eps = torch.randn_like(std)
            latent_samples = self.mu + eps * std
            weights = torch.ones_like(latent_samples[..., 0])
        else:
            eps = torch.randn(
                [*std.shape[:-1], n_samples, self.mu.shape[-1]], device=std.device
            )
            # Add a sample dimension before the latent dimension
            latent_samples = self.mu.unsqueeze(-2) + eps * std.unsqueeze(-2)
            weights = torch.ones_like(latent_samples[..., 0]) / n_samples
        return latent_samples, weights

    def kl_loss(
        self,
        other: "GaussianLatentDistribution",
        threshold: float = 0,
        mask_z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the KL divergence between two latent distributions."""
        assert type(other) == GaussianLatentDistribution
        kl_loss = (
            0.5
            * (
                other.logvar
                - self.logvar
                + ((self.mu - other.mu).square() + self.logvar.exp())
                / other.logvar.exp()
                - 1
            )
        ).clamp_min(threshold)
        if mask_z is None:
            return kl_loss.mean()
        else:
            assert mask_z.any()
            return torch.sum(kl_loss.mean(-1) * mask_z) / torch.sum(mask_z)

    def sampling_loss(self) -> torch.Tensor:
        return torch.zeros(1, device=self.mu.device)

    def average(
        self, other: "GaussianLatentDistribution", weight_other: torch.Tensor
    ) -> "GaussianLatentDistribution":
        assert type(other) == GaussianLatentDistribution
        assert other.mu.shape == self.mu.shape
        average_log_var = (
            self.logvar.exp() * (1 - weight_other) + other.logvar.exp() * weight_other
        ).log()
        return GaussianLatentDistribution(
            torch.cat(
                (
                    self.mu * (1 - weight_other) + other.mu * weight_other,
                    average_log_var,
                ),
                dim=-1,
            )
        )

    def log_dict(self, type: str) -> dict:
        return {
            f"latent/{type}/abs_mean": self.mu.abs().mean(),
            f"latent/{type}/std": (self.logvar * 0.5).exp().mean(),
        }


class QuantizedLatentDistribution(AbstractLatentDistribution):
    """Quantized latent distribution.
    It is defined with a codebook of quantized latents and a continuous latent.
    The distribution is based on distances of the continuous latent to the codebook.
    Sampling only quantizes the continuous latent.

    Args:
        continuous_latent : Continuous latent representation of shape (some_shape, latent_dim)
        codebook : Codebook of shape (num_embeddings, latent_dim)
    """

    def __init__(
        self,
        continuous_latent: torch.Tensor,
        codebook: torch.Tensor,
        flush_weights: Callable[[], None],
        get_weights: Callable[[], torch.Tensor],
        index_add_one_weights: Callable[[torch.Tensor], None],
    ):
        super().__init__()
        self.register_buffer("continuous_latent", continuous_latent, False)
        self.register_buffer("codebook", codebook, False)
        self.flush_weights = flush_weights
        self.get_weights = get_weights
        self.index_add_one_weights = index_add_one_weights
        self.quantization_loss = None
        self.accuracy = None

    def sample(
        self, n_samples: int = 0, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize the continuous latent from the latent dictionary.

        Args:
            latent: (batch_size, num_agents, latent_dim) Continuous latent input

        Returns:
            quantized_latent, quantization_loss
        """
        assert n_samples == 0, "Only one sample is supported for quantized latent"

        distances_to_quantized = (
            (
                self.codebook.view(1, 1, *self.codebook.shape)
                - self.continuous_latent.unsqueeze(-2)
            )
            .square()
            .sum(-1)
        )
        batch_size, num_agents, num_vq = distances_to_quantized.shape

        self.soft_one_hot = (
            (-100 * distances_to_quantized)
            .softmax(dim=-1)
            .view(batch_size, num_agents, num_vq)
        )
        # quantized, args_selected = self.sample(soft_one_hot)
        _, args_selected = torch.min(distances_to_quantized, dim=-1)
        quantized = self.codebook[args_selected, :]
        args_selected = args_selected.view(-1)

        # Update weights
        self.index_add_one_weights(args_selected)

        distances_to_quantized = distances_to_quantized.view(
            batch_size * num_agents, num_vq
        )

        # Resample useless latent vectors
        random_latents = self.continuous_latent.view(
            batch_size * num_agents, self.codebook.shape[-1]
        )[torch.randint(batch_size * num_agents, (num_vq,))]
        codebook_weights = self.get_weights()
        total_samples = codebook_weights.sum()
        # TODO: The value 100 is arbitrary, should it be a parameter?
        # The uselessness of a codebook vector is defined by the number of times it has been sampled:
        # if it has been sampled less than 1% of the time, it is pushed towards a random continuous latent sample.
        # This prevents the codebook from being dominated by a few vectors.
        self.uselessness = (
            (
                torch.where(
                    (codebook_weights < total_samples / (100 * num_vq)).unsqueeze(-1),
                    random_latents.detach() - self.codebook,
                    torch.zeros_like(self.codebook),
                ).abs()
                + 1
            )
            .log()
            .sum(-1)
            .mean()
        )
        # TODO: The value 1e6 is arbitrary, should it be a parameter?
        if total_samples > 1e6 * num_vq:
            # Flush the codebook weights when the number of samples is too high.
            # This prevents the codebook from being dominated by its history
            # if a few vectors were visited a lot, and also prevents overflows.
            self.flush_weights()

        # commit_loss = (self.continuous_latent - quantized.detach()).square().clamp_min(self.distance_threshold).sum(-1).mean()

        self.quantization_loss = (
            (self.continuous_latent - quantized).square().sum(-1).mean()
        )

        # Straight-through estimator: forward the quantized value, backpropagate through the continuous latent
        quantized = (
            quantized.detach()
            + self.continuous_latent
            - self.continuous_latent.detach()
        )

        self.latent_diversity = (
            (self.continuous_latent[None, ...] - self.continuous_latent[:, None, ...])
            .square()
            .sum(-1)
            .mean()
        )

        return quantized, torch.ones_like(quantized[..., 0]) / num_vq

    def kl_loss(
        self,
        other: "ClassifiedLatentDistribution",
        threshold: float = 0,
        mask_z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the cross entropy between two latent distributions."""
        assert type(other) == ClassifiedLatentDistribution
        min_logits = -10
        max_logits = 10
        pred_log = other.logits.clamp(min_logits, max_logits).log_softmax(-1)
        self_pred = self.soft_one_hot
        self.accuracy = (self_pred.argmax(-1) == other.logits.argmax(-1)).float().mean()
        return -2 * (pred_log * self_pred).sum(-1).mean()

    def sampling_loss(self) -> torch.Tensor:
        if self.quantization_loss is None:
            self.sample()
        return 0.5 * (
            self.quantization_loss + self.uselessness + 0.001 * self.latent_diversity
        )

    def average(
        self, other: "QuantizedLatentDistribution", weight_other: torch.Tensor
    ) -> "QuantizedLatentDistribution":
        raise NotImplementedError(
            "Average is not implemented for QuantizedLatentDistribution"
        )

    def log_dict(self, type: str) -> dict:
        log_dict = {
            f"latent/{type}/quantization_loss": self.quantization_loss,
            f"latent/{type}/uselessness": self.uselessness,
            f"latent/{type}/latent_diversity": self.latent_diversity,
            f"latent/{type}/codebook_abs_mean": self.codebook.abs().mean(),
            f"latent/{type}/codebook_std": self.codebook.std(),
            f"latent/{type}/latent_abs_mean": self.continuous_latent.abs().mean(),
            f"latent/{type}/latent_std": self.continuous_latent.std(),
        }
        if self.accuracy is not None:
            log_dict[f"latent/{type}/accuracy"] = self.accuracy
        return log_dict


class ClassifiedLatentDistribution(AbstractLatentDistribution):
    """Classified latent distribution.
    It is defined with a codebook of quantized latents and a probability distribution over the codebook elements.

    Args:
        logits : Logits of shape (some_shape, num_embeddings)
        codebook : Codebook of shape (num_embeddings, latent_dim)
    """

    def __init__(self, logits: torch.Tensor, codebook: torch.Tensor):
        super().__init__()
        self.register_buffer("logits", logits, persistent=False)
        self.register_buffer("codebook", codebook, persistent=False)

    def sample(
        self, n_samples: int = 0, replacement: bool = True, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_agents, num_vq = self.logits.shape
        squeeze_out = False
        if n_samples == 0:
            squeeze_out = True
            n_samples = 1
        elif n_samples > self.codebook.shape[0]:
            warnings.warn(
                f"Requested {n_samples} samples but only {self.codebook.shape[0]} are available in the discrete latent space. Switching to replacement=True to support it."
            )
            replacement = True

        if self.training:
            # TODO: should we make the temperature a parameter?
            all_weights, indices = relaxed_one_hot_categorical_without_replacement(
                logits=self.logits, temperature=1, num_samples=n_samples
            )
            selected_latents = self.codebook[indices, :]
            # Cumulative mask of indices that have been sampled in order of probability
            mask_selection = torch.nn.functional.one_hot(indices, num_vq).cumsum(-2)
            mask_selection[..., 1:, :] = mask_selection[..., :-1, :]
            mask_selection[..., 0, :] = 0.0
            # Remove the probability of previous samples to account for sampling without replacement
            masked_weights = all_weights.unsqueeze(-2) * (1 - mask_selection.float())
            # Renormalize the probabilities to sum to 1
            masked_weights = masked_weights / masked_weights.sum(-1, keepdim=True)

            latent_samples = (
                masked_weights.unsqueeze(-1)
                * self.codebook[None, None, None, ...].detach()
            ).sum(-2)
            latent_samples = (
                selected_latents.detach() + latent_samples - latent_samples.detach()
            )
            probs = torch.gather(self.logits.softmax(-1), -1, indices)
        else:
            probs = self.logits.softmax(-1)
            samples = torch.multinomial(
                probs.view(batch_size * num_agents, num_vq),
                n_samples,
                replacement=replacement,
            )
            latent_samples = self.codebook[samples]
            probs = torch.gather(
                probs, -1, samples.view(batch_size, num_agents, n_samples)
            )

        if squeeze_out:
            latent_samples = latent_samples.view(
                batch_size, num_agents, self.codebook.shape[-1]
            )
        else:
            latent_samples = latent_samples.view(
                batch_size, num_agents, n_samples, self.codebook.shape[-1]
            )
        return latent_samples, probs

    def kl_loss(
        self,
        other: "ClassifiedLatentDistribution",
        threshold: float = 0,
        mask_z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the cross entropy between two latent distributions, self being the reference distribution and other the distribution to compare."""
        assert type(other) == ClassifiedLatentDistribution
        min_logits = -10
        max_logits = 10
        pred_log = other.logits.clamp(min_logits, max_logits).log_softmax(-1)
        self_pred = (
            (0.5 * (self.logits.detach() + self.logits))
            .clamp(min_logits, max_logits)
            .softmax(-1)
        )
        return -2 * (pred_log * self_pred).sum(-1).mean()

    def sampling_loss(self) -> torch.Tensor:
        return torch.zeros(1, device=self.logits.device)

    def average(
        self, other: "ClassifiedLatentDistribution", weight_other: torch.Tensor
    ) -> "ClassifiedLatentDistribution":
        assert type(other) == ClassifiedLatentDistribution
        assert (self.codebook == other.codebook).all()
        return ClassifiedLatentDistribution(
            (
                self.logits.exp() * (1 - weight_other)
                + other.logits.exp() * weight_other
            ).log(),
            self.codebook,
        )

    def log_dict(self, type: str) -> dict:
        max_probs, _ = self.logits.softmax(-1).max(-1)
        return {
            f"latent/{type}/codebook_abs_mean": self.codebook.abs().mean(),
            f"latent/{type}/codebook_std": self.codebook.std(),
            f"latent/{type}/class_max_mean": max_probs.mean(),
            f"latent/{type}/class_max_std": max_probs.std(),
        }


class QuantizedDistributionCreator(nn.Module):
    """Creates a distribution from a latent vector."""

    def __init__(
        self,
        latent_dim: int,
        num_embeddings: int,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings
        self.codebook = nn.Parameter(torch.randn(num_embeddings, latent_dim))
        self.register_buffer(
            "codebook_weights",
            torch.ones(num_embeddings, requires_grad=False),
            persistent=False,
        )

    def _flush_codebook_weights(self):
        self.codebook_weights = torch.ones_like(self.codebook_weights)

    def _get_codebook_weights(self):
        return self.codebook_weights

    def _index_add_one_codebook_weight(self, indices: torch.Tensor):
        self.codebook_weights = self.codebook_weights.index_add(
            0,
            indices.flatten(),
            torch.ones_like(self.codebook_weights[indices]),
        )

    def forward(self, latent: torch.Tensor) -> AbstractLatentDistribution:
        if latent.shape[-1] == self.latent_dim:
            return QuantizedLatentDistribution(
                latent,
                self.codebook,
                self._flush_codebook_weights,
                self._get_codebook_weights,
                self._index_add_one_codebook_weight,
            )
        elif latent.shape[-1] == self.num_embeddings:
            return ClassifiedLatentDistribution(
                latent,
                self.codebook,
            )
        else:
            raise ValueError(f"Latent vector has wrong dimension: {latent.shape[-1]}")
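For the Gaussian case above, kl_loss implements the closed-form KL divergence between two diagonal Gaussians, KL(N(mu1, var1) || N(mu2, var2)) = 0.5 (log var2 - log var1 + (var1 + (mu1 - mu2)^2) / var2 - 1), averaged over all dimensions. A small self-contained check against torch.distributions; shapes and values are arbitrary illustrations, not repository settings:

import torch

from risk_biased.models.latent_distributions import GaussianLatentDistribution

# Hypothetical shapes: 2 scenes, 3 agents, latent_dim 4; the encoder output
# packs mean and log-variance along the last dimension (hence size 8).
posterior = GaussianLatentDistribution(torch.randn(2, 3, 8))
# A standard-normal prior packed the same way: zero mean, zero log-variance.
prior = GaussianLatentDistribution(torch.zeros(2, 3, 8))

# The closed-form kl_loss should agree with torch.distributions.
kl_code = posterior.kl_loss(prior)
p = torch.distributions.Normal(posterior.mu, (posterior.logvar / 2).exp())
q = torch.distributions.Normal(prior.mu, (prior.logvar / 2).exp())
kl_ref = torch.distributions.kl_divergence(p, q).mean()
assert torch.allclose(kl_code, kl_ref, atol=1e-4)

# Reparametrized samples keep the gradient path to the encoder output.
samples, weights = posterior.sample(n_samples=5)
print(samples.shape)  # torch.Size([2, 3, 5, 4])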
risk_biased/models/map_encoder.py
ADDED
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn
from risk_biased.models.nn_blocks import (
    SequenceEncoderLSTM,
    SequenceEncoderMLP,
    SequenceEncoderMaskedLSTM,
)

from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.mlp import MLP


class MapEncoderNN(nn.Module):
    """MLP encoder neural network that encodes map objects.

    Args:
        params: dataclass defining the necessary parameters
    """

    def __init__(self, params: CVAEParams) -> None:
        super().__init__()
        self._encoder = SequenceEncoderMLP(
            params.map_state_dim,
            params.hidden_dim,
            params.num_hidden_layers,
            params.max_size_lane,
            params.is_mlp_residual,
        )

    def forward(self, map, mask_map):
        """Forward function encoding map object sequences of features into object features.

        Args:
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor of bool mask
        """
        encoded_map = self._encoder(map, mask_map)
        return encoded_map
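A usage sketch for MapEncoderNN. The constructor only reads five fields from the params dataclass (map_state_dim, hidden_dim, num_hidden_layers, max_size_lane, is_mlp_residual), so a plain namespace stands in for CVAEParams here; all values are illustrative placeholders:

from types import SimpleNamespace

import torch

from risk_biased.models.map_encoder import MapEncoderNN

# Illustrative placeholder values, not repository defaults.
params = SimpleNamespace(
    map_state_dim=2,
    hidden_dim=64,
    num_hidden_layers=2,
    max_size_lane=20,
    is_mlp_residual=True,
)
encoder = MapEncoderNN(params)

batch_size, num_objects = 4, 7
map_features = torch.randn(batch_size, num_objects, 20, 2)
mask_map = torch.ones(batch_size, num_objects, 20, dtype=torch.bool)

encoded_map = encoder(map_features, mask_map)
print(encoded_map.shape)  # torch.Size([4, 7, 64])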
risk_biased/models/mlp.py
ADDED
@@ -0,0 +1,60 @@
from collections import OrderedDict

import torch
import torch.nn as nn


class MLP(nn.Module):
    """Basic MLP implementation with FC layers and ReLU activation.

    Args:
        input_dim : dimension of the input variable
        output_dim : dimension of the output variable
        h_dim : dimension of a hidden layer of MLP
        num_h_layers : number of hidden layers in MLP
        add_residual : set to True to add the input to the output (res-net), set to False to have a pure MLP
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        h_dim: int,
        num_h_layers: int,
        add_residual: bool,
    ) -> None:
        super().__init__()
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._h_dim = h_dim
        self._num_h_layers = num_h_layers

        layers = OrderedDict()
        if num_h_layers > 0:
            layers["fc_0"] = nn.Linear(input_dim, h_dim)
            layers["relu_0"] = nn.ReLU()
        else:
            h_dim = input_dim
        for ii in range(1, num_h_layers):
            layers["fc_{}".format(ii)] = nn.Linear(h_dim, h_dim)
            layers["relu_{}".format(ii)] = nn.ReLU()
        layers["fc_{}".format(num_h_layers)] = nn.Linear(h_dim, self._output_dim)

        self.mlp = nn.Sequential(layers)
        if add_residual:
            self.residual_layer = nn.Linear(input_dim, output_dim)
        else:
            self.residual_layer = lambda x: 0
        self._layer_norm = nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward function for MLP

        Args:
            input (torch.Tensor): (batch_size, input_dim) tensor

        Returns:
            torch.Tensor: (batch_size, output_dim) tensor
        """
        return self._layer_norm(self.mlp(input) + self.residual_layer(input))
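A minimal usage sketch for this MLP; the dimensions are arbitrary. With add_residual=True the output is layer_norm(mlp(x) + linear(x)), so gradients reach the input through both the deep path and the linear skip:

import torch

from risk_biased.models.mlp import MLP

# A small residual MLP: 16 -> 64 -> 64 -> 32, with a linear skip connection
# and a LayerNorm on the output (dimensions chosen for illustration).
mlp = MLP(input_dim=16, output_dim=32, h_dim=64, num_h_layers=2, add_residual=True)

x = torch.randn(8, 16)
y = mlp(x)
print(y.shape)  # torch.Size([8, 32])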
risk_biased/models/multi_head_attention.py
ADDED
@@ -0,0 +1,81 @@
# Implementation from https://einops.rocks/pytorch-examples.html slightly changed


import math

from typing import Tuple
import torch
from torch import nn
from einops import rearrange, repeat


class MultiHeadAttention(nn.Module):
    """
    A slightly modified version of the multi-head attention implementation from https://einops.rocks/pytorch-examples.html.
    It keeps the original dimension division per head and masks the attention matrix both before and after the softmax to support full row masking.

    Args:
        d_model: the input feature dimension of the model
        n_head: the number of heads in the multihead attention
        d_k: the dimension of the key and query in the multihead attention
        d_v: the dimension of the value in the multihead attention
    """

    def __init__(self, d_model: int, n_head: int, d_k: int, d_v: int):
        super().__init__()
        self.n_head = n_head

        self.w_qs = nn.Linear(d_model, int(d_k / n_head) * n_head)
        self.w_ks = nn.Linear(d_model, int(d_k / n_head) * n_head)
        self.w_vs = nn.Linear(d_model, int(d_v / n_head) * n_head)
        self.w_rs = nn.Linear(d_model, int(d_v / n_head) * n_head)

        nn.init.normal_(self.w_qs.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_v)))
        nn.init.normal_(self.w_rs.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_v)))

        self.fc = nn.Linear(int(d_v / n_head) * n_head, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the masked multi-head attention given the query, key and value tensors.

        Args:
            q: the query tensor of shape [batch_size, number_of_agents, d_model]
            k: the key tensor of shape [batch_size, number_of_objects, d_model]
            v: the value tensor of shape [batch_size, number_of_objects, d_model]
            mask: the mask tensor of shape [batch_size, number_of_agents, number_of_objects]

        Returns:
            [
                The attention output tensor of shape [batch_size, number_of_agents, d_model],
                The attention matrix of shape [n_head, batch_size, number_of_agents, number_of_objects]
            ]
        """
        residual = q.clone()
        r = self.w_rs(q)
        q = rearrange(self.w_qs(q), "b a (head k) -> head b a k", head=self.n_head)
        k = rearrange(self.w_ks(k), "b o (head k) -> head b o k", head=self.n_head)
        v = rearrange(self.w_vs(v), "b o (head v) -> head b o v", head=self.n_head)
        attn = torch.einsum("hbak,hbok->hbao", [q, k]) / math.sqrt(q.shape[-1])
        if mask is not None:
            # b: batch, a: agent, o: object, h: head
            mask = repeat(mask, "b a o -> h b a o", h=self.n_head)
            attn = attn.masked_fill(mask == 0, -math.inf)
        attn = torch.softmax(attn, dim=3)
        if mask is not None:
            # Here we need to mask again because some rows might be all -inf, which the softmax turns into NaN
            attn = attn.masked_fill(mask == 0, 0)
        output = torch.einsum("hbao,hbov->hbav", [attn, v])
        output = rearrange(output, "head b a v -> b a (head v)")
        output = self.fc(output * r)
        output = self.layer_norm(output + residual)
        return output, attn
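The point of the double masking above is that an agent whose entire key row is masked would otherwise come out of the softmax as a NaN attention row. A short sketch exercising exactly that case (all sizes are arbitrary illustrations):

import torch

from risk_biased.models.multi_head_attention import MultiHeadAttention

# Cross-attention of 3 agents over 5 map objects, 4 heads.
mha = MultiHeadAttention(d_model=64, n_head=4, d_k=64, d_v=64)

batch_size, num_agents, num_objects = 2, 3, 5
q = torch.randn(batch_size, num_agents, 64)
k = torch.randn(batch_size, num_objects, 64)
v = torch.randn(batch_size, num_objects, 64)

# Mask out every object for the last agent: its attention row is zeroed
# by the second masked_fill instead of producing NaNs.
mask = torch.ones(batch_size, num_agents, num_objects)
mask[:, -1, :] = 0

output, attn = mha(q, k, v, mask=mask)
print(output.shape)    # torch.Size([2, 3, 64])
print(attn[0, 0, -1])  # all zeros for the fully masked agent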
risk_biased/models/nn_blocks.py
ADDED
@@ -0,0 +1,626 @@
from einops.layers.torch import Rearrange
from einops import rearrange, repeat
import torch
import torch.nn as nn

from risk_biased.models.multi_head_attention import MultiHeadAttention
from risk_biased.models.context_gating import ContextGating
from risk_biased.models.mlp import MLP


class SequenceEncoderMaskedLSTM(nn.Module):
    """MLP followed with a masked LSTM implementation with one layer.

    Args:
        input_dim : dimension of the input variable
        h_dim : dimension of a hidden layer of MLP
    """

    def __init__(self, input_dim: int, h_dim: int) -> None:
        super().__init__()
        self._group_objects = Rearrange("b o ... -> (b o) ...")
        self._embed = nn.Linear(in_features=input_dim, out_features=h_dim)
        self._lstm = nn.LSTMCell(
            input_size=h_dim, hidden_size=h_dim
        )  # expects (batch, seq, features)
        self.h0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
        self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))

    def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
        """Forward function for the masked LSTM sequence encoder

        Args:
            input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
            mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good, False if data is missing)

        Returns:
            torch.Tensor: (batch_size, num_objects, output_dim) tensor
        """

        batch_size, num_objects, seq_len, _ = input.shape
        split_objects = Rearrange("(b o) f -> b o f", b=batch_size, o=num_objects)

        input = self._group_objects(input)
        mask_input = self._group_objects(mask_input)
        embedded_input = self._embed(input)

        # One to many encoding of the input sequence with masking for missing points
        mask_input = mask_input.float()
        h = mask_input[:, 0, None] * embedded_input[:, 0, :] + (
            1 - mask_input[:, 0, None]
        ) * repeat(self.h0, "b f -> (size b) f", size=batch_size * num_objects)
        c = repeat(self.c0, "b f -> (size b) f", size=batch_size * num_objects)
        for i in range(seq_len):
            new_input = (
                mask_input[:, i, None] * embedded_input[:, i, :]
                + (1 - mask_input[:, i, None]) * h
            )
            h, c = self._lstm(new_input, (h, c))
        return split_objects(h)


class SequenceEncoderLSTM(nn.Module):
    """MLP followed with an LSTM with one layer.

    Args:
        input_dim : dimension of the input variable
        h_dim : dimension of a hidden layer of MLP
    """

    def __init__(self, input_dim: int, h_dim: int) -> None:
        super().__init__()
        self._group_objects = Rearrange("b o ... -> (b o) ...")
        self._embed = nn.Linear(in_features=input_dim, out_features=h_dim)
        self._lstm = nn.LSTM(
            input_size=h_dim,
            hidden_size=h_dim,
            batch_first=True,
        )  # expects (batch, seq, features)
        self.h0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
        self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))

    def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
        """Forward function for the LSTM sequence encoder

        Args:
            input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
            mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good, False if data is missing)

        Returns:
            torch.Tensor: (batch_size, num_objects, output_dim) tensor
        """

        batch_size, num_objects, seq_len, _ = input.shape
        split_objects = Rearrange("(b o) f -> b o f", b=batch_size, o=num_objects)

        input = self._group_objects(input)
        mask_input = self._group_objects(mask_input)
        embedded_input = self._embed(input)

        # One to many encoding of the input sequence with masking for missing points
        mask_input = mask_input.float()
        h = (
            mask_input[:, 0, None] * embedded_input[:, 0, :]
            + (1 - mask_input[:, 0, None])
            * repeat(
                self.h0, "one f -> one size f", size=batch_size * num_objects
            ).contiguous()
        )
        c = repeat(
            self.c0, "one f -> one size f", size=batch_size * num_objects
        ).contiguous()
        _, (h, _) = self._lstm(embedded_input, (h, c))
        # for i in range(seq_len):
        #     new_input = (
        #         mask_input[:, i, None] * embedded_input[:, i, :]
        #         + (1 - mask_input[:, i, None]) * h
        #     )
        #     h, c = self._lstm(new_input, (h, c))
        return split_objects(h.squeeze(0))


class SequenceEncoderMLP(nn.Module):
    """MLP implementation.

    Args:
        input_dim : dimension of the input variable
        h_dim : dimension of a hidden layer of MLP
        num_layers: number of layers to use in the MLP
        sequence_length: dimension of the input sequence
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self,
        input_dim: int,
        h_dim: int,
        num_layers: int,
        sequence_length: int,
        is_mlp_residual: bool,
    ) -> None:
        super().__init__()
        self._mlp = MLP(
            input_dim * sequence_length, h_dim, h_dim, num_layers, is_mlp_residual
        )

    def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
        """Forward function for the MLP sequence encoder

        Args:
            input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
            mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good, False if data is missing)

        Returns:
            torch.Tensor: (batch_size, num_objects, output_dim) tensor
        """

        batch_size, num_objects, _, _ = input.shape
        input = input * mask_input.unsqueeze(-1)
        h = rearrange(input, "b o s f -> (b o) (s f)")
        mask_input = rearrange(mask_input, "b o s -> (b o) s")
        if h.shape[-1] == 0:
            h = h.view(batch_size, 0, h.shape[0])
        else:
            h = self._mlp(h)
            h = rearrange(h, "(b o) f -> b o f", b=batch_size, o=num_objects)

        return h


class SequenceDecoderLSTM(nn.Module):
    """A one to many LSTM implementation with one layer.

    Args:
        h_dim : dimension of a hidden layer
    """

    def __init__(self, h_dim: int) -> None:
        super().__init__()
        self._group_objects = Rearrange("b o f -> (b o) f")
        self._lstm = nn.LSTM(input_size=h_dim, hidden_size=h_dim)
        self._out_layer = nn.Linear(in_features=h_dim, out_features=h_dim)
        self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))

    def forward(self, input: torch.Tensor, sequence_length: int) -> torch.Tensor:
        """Forward function for the LSTM sequence decoder

        Args:
            input (torch.Tensor): (batch_size, num_objects, input_dim) tensor
            sequence_length: output sequence length to create
        Returns:
            torch.Tensor: (batch_size, num_objects, sequence_length, output_dim) tensor
        """

        batch_size, num_objects, _ = input.shape

        h = repeat(input, "b o f -> one (b o) f", one=1).contiguous()
        c = repeat(
            self.c0, "one f -> one size f", size=batch_size * num_objects
        ).contiguous()
        seq_h = repeat(h, "one b f -> (one t) b f", t=sequence_length).contiguous()
        h, (_, _) = self._lstm(seq_h, (h, c))
        h = rearrange(h, "t (b o) f -> b o t f", b=batch_size, o=num_objects)
        return self._out_layer(h)


class SequenceDecoderMLP(nn.Module):
    """A one to many MLP implementation.

    Args:
        h_dim : dimension of a hidden layer
        num_layers: number of layers to use in the MLP
        sequence_length: output sequence length to return
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self, h_dim: int, num_layers: int, sequence_length: int, is_mlp_residual: bool
    ) -> None:
        super().__init__()
        self._mlp = MLP(
            h_dim, h_dim * sequence_length, h_dim, num_layers, is_mlp_residual
        )

    def forward(self, input: torch.Tensor, sequence_length: int) -> torch.Tensor:
        """Forward function for the MLP sequence decoder

        Args:
            input (torch.Tensor): (batch_size, num_objects, input_dim) tensor
            sequence_length: output sequence length to create
        Returns:
            torch.Tensor: (batch_size, num_objects, sequence_length, output_dim) tensor
        """

        batch_size, num_objects, _ = input.shape

        h = rearrange(input, "b o f -> (b o) f")
        h = self._mlp(h)
        h = rearrange(
            h, "(b o) (s f) -> b o s f", b=batch_size, o=num_objects, s=sequence_length
        )
        return h


class AttentionBlock(nn.Module):
    """Block performing agent-map cross attention->sigmoid(linear)->+residual->layer_norm->agent-agent attention->ReLU(linear)->+residual->layer_norm
    Args:
        hidden_dim: feature dimension
        num_attention_heads: number of attention heads to use
    """

    def __init__(self, hidden_dim: int, num_attention_heads: int):
        super().__init__()
        self._num_attention_heads = num_attention_heads
        self._agent_map_attention = MultiHeadAttention(
            hidden_dim, num_attention_heads, hidden_dim, hidden_dim
        )
        self._lin1 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm1 = nn.LayerNorm(hidden_dim)
        self._agent_agent_attention = MultiHeadAttention(
            hidden_dim, num_attention_heads, hidden_dim, hidden_dim
        )
        self._lin2 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm2 = nn.LayerNorm(hidden_dim)
        self._activation = nn.ReLU()

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function of the block, returning only the output (no attention matrix)
        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor, True if agent track is good, False if it is just padding
            encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
            mask_map: (batch_size, num_objects) tensor, True if object is good, False if it is just padding
        """

        # Check if map info is available. If not, don't compute cross-attention with it
        if mask_map.any():
            mask_agent_map = torch.einsum("ba,bo->bao", mask_agents, mask_map)
            h, _ = self._agent_map_attention(
                encoded_agents + encoded_absolute_agents,
                encoded_map,
                encoded_map,
                mask=mask_agent_map,
            )
            h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
            h = torch.sigmoid(self._lin1(h))
            h = self._layer_norm1(encoded_agents + h)
        else:
            h = self._layer_norm1(encoded_agents)

        h_res = h.clone()
        agent_agent_mask = torch.einsum("ba,be->bae", mask_agents, mask_agents)
        h = h + encoded_absolute_agents
        h, _ = self._agent_agent_attention(h, h, h, mask=agent_agent_mask)
        h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
        h = self._activation(self._lin2(h))
        h = self._layer_norm2(h_res + h)
        return h


class CG_block(nn.Module):
    """Block performing agent-map context gating
    Args:
        hidden_dim: feature dimension
        dim_expansion: multiplicative factor on the hidden dimension for the global context representation
        num_layers: number of layers to use in the MLP for context encoding
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self,
        hidden_dim: int,
        dim_expansion: int,
        num_layers: int,
        is_mlp_residual: bool,
    ):
        super().__init__()
        self._agent_map = ContextGating(
            hidden_dim,
            hidden_dim * dim_expansion,
            num_layers=num_layers,
            is_mlp_residual=is_mlp_residual,
        )
        self._lin1 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm1 = nn.LayerNorm(hidden_dim)
        self._agent_agent = ContextGating(
            hidden_dim, hidden_dim * dim_expansion, num_layers, is_mlp_residual
        )
        self._lin2 = nn.Linear(hidden_dim, hidden_dim)
        self._activation = nn.ReLU()

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        global_context: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function of the block, returning the output and the global context
        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor, True if agent track is good, False if it is just padding
            encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
            mask_map: (batch_size, num_objects) tensor, True if object is good, False if it is just padding
            global_context: (batch_size, dim_context) tensor representing the global context
        """

        # Check if map info is available. If not, don't compute cross-interaction with it
        if mask_map.any():
            s, global_context = self._agent_map(
                encoded_agents + encoded_absolute_agents, encoded_map, global_context
            )
            s = s * mask_agents.unsqueeze(-1)
            s = self._activation(self._lin1(s))
            s = self._layer_norm1(encoded_agents + s)
        else:
            s = self._layer_norm1(encoded_agents)

        s = s + encoded_absolute_agents
        s, global_context = self._agent_agent(s, s, global_context)
        s = s * mask_agents.unsqueeze(-1)
        s = self._lin2(s)
        return s, global_context


class HybridBlock(nn.Module):
    """Block performing agent-map cross context_gating->ReLU(linear)->+residual->layer_norm->agent-agent attention->ReLU(linear)->+residual->layer_norm
    Args:
        hidden_dim: feature dimension
        num_attention_heads: number of attention heads to use
        dim_expansion: multiplicative factor on the hidden dimension for the global context representation
        num_layers: number of layers to use in the MLP for context encoding
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self,
        hidden_dim: int,
        num_attention_heads: int,
        dim_expansion: int,
        num_layers: int,
        is_mlp_residual: bool,
    ):
        super().__init__()
        self._num_attention_heads = num_attention_heads
        self._agent_map_cg = ContextGating(
            hidden_dim,
            hidden_dim * dim_expansion,
            num_layers=num_layers,
            is_mlp_residual=is_mlp_residual,
        )
        self._lin1 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm1 = nn.LayerNorm(hidden_dim)
        self._agent_agent_attention = MultiHeadAttention(
            hidden_dim, num_attention_heads, hidden_dim, hidden_dim
        )
        self._lin2 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm2 = nn.LayerNorm(hidden_dim)
        self._activation = nn.ReLU()

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        global_context: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function of the block, returning the output and the context (no attention matrix)
        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor, True if agent track is good, False if it is just padding
            encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
            mask_map: (batch_size, num_objects) tensor, True if object is good, False if it is just padding
            global_context: (batch_size, dim_context) tensor representing the global context
        """

        # Check if map info is available. If not, don't compute cross-context gating with it
        if mask_map.any():
            # mask_agent_map = torch.logical_not(
            #     torch.einsum("ba,bo->bao", mask_agents, mask_map)
            # )
            h, global_context = self._agent_map_cg(
                encoded_agents + encoded_absolute_agents, encoded_map, global_context
            )
            h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
            h = self._activation(self._lin1(h))
            h = self._layer_norm1(encoded_agents + h)
        else:
            h = self._layer_norm1(encoded_agents)

        h_res = h.clone()
        agent_agent_mask = torch.einsum("ba,be->bae", mask_agents, mask_agents)
        h = h + encoded_absolute_agents
        h, _ = self._agent_agent_attention(h, h, h, mask=agent_agent_mask)
        h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
        h = self._activation(self._lin2(h))
        h = self._layer_norm2(h_res + h)
        return h, global_context


class MCG(nn.Module):
    """Multiple context encoding blocks
    Args:
        hidden_dim: feature dimension
        dim_expansion: multiplicative factor on the hidden dimension for the global context representation
        num_layers: number of layers to use in the MLP for context encoding
        num_blocks: number of successive context encoding blocks to use in the module
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self,
        hidden_dim: int,
        dim_expansion: int,
        num_layers: int,
        num_blocks: int,
        is_mlp_residual: bool,
    ):
        super().__init__()
        self.initial_global_context = nn.parameter.Parameter(
            torch.ones(1, hidden_dim * dim_expansion)
        )
        list_cg = []
        for i in range(num_blocks):
            list_cg.append(
                CG_block(hidden_dim, dim_expansion, num_layers, is_mlp_residual)
            )
        self.mcg = nn.ModuleList(list_cg)

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function of the block, returning only the output (no context)
        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor, True if agent track is good, False if it is just padding
|
495 |
+
encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
|
496 |
+
encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
|
497 |
+
mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
|
498 |
+
"""
|
499 |
+
s = encoded_agents
|
500 |
+
c = self.initial_global_context
|
501 |
+
sum_s = s
|
502 |
+
sum_c = c
|
503 |
+
for i, cg in enumerate(self.mcg):
|
504 |
+
s_new, c_new = cg(
|
505 |
+
s, mask_agents, encoded_absolute_agents, encoded_map, mask_map, c
|
506 |
+
)
|
507 |
+
sum_s = sum_s + s_new
|
508 |
+
sum_c = sum_c + c_new
|
509 |
+
s = (sum_s / (i + 2)).clone()
|
510 |
+
c = (sum_c / (i + 2)).clone()
|
511 |
+
return s
|
512 |
+
|
513 |
+
|
514 |
+
class MAB(nn.Module):
|
515 |
+
"""Multiple Attention Blocks
|
516 |
+
Args:
|
517 |
+
hidden_dim: feature dimension
|
518 |
+
num_attention_heads: number of attention heads to use
|
519 |
+
num_blocks: number of successive blocks to use in the module.
|
520 |
+
"""
|
521 |
+
|
522 |
+
def __init__(
|
523 |
+
self,
|
524 |
+
hidden_dim: int,
|
525 |
+
num_attention_heads: int,
|
526 |
+
num_blocks: int,
|
527 |
+
):
|
528 |
+
super().__init__()
|
529 |
+
list_attention = []
|
530 |
+
for i in range(num_blocks):
|
531 |
+
list_attention.append(AttentionBlock(hidden_dim, num_attention_heads))
|
532 |
+
self.attention_blocks = nn.ModuleList(list_attention)
|
533 |
+
|
534 |
+
def forward(
|
535 |
+
self,
|
536 |
+
encoded_agents: torch.Tensor,
|
537 |
+
mask_agents: torch.Tensor,
|
538 |
+
encoded_absolute_agents: torch.Tensor,
|
539 |
+
encoded_map: torch.Tensor,
|
540 |
+
mask_map: torch.Tensor,
|
541 |
+
) -> torch.Tensor:
|
542 |
+
"""Forward function of the block, returning only the output (no attention matrix)
|
543 |
+
Args:
|
544 |
+
encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
|
545 |
+
mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
|
546 |
+
encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
|
547 |
+
encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
|
548 |
+
mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
|
549 |
+
"""
|
550 |
+
h = encoded_agents
|
551 |
+
sum_h = h
|
552 |
+
for i, attention in enumerate(self.attention_blocks):
|
553 |
+
h_new = attention(
|
554 |
+
h, mask_agents, encoded_absolute_agents, encoded_map, mask_map
|
555 |
+
)
|
556 |
+
sum_h = sum_h + h_new
|
557 |
+
h = (sum_h / (i + 2)).clone()
|
558 |
+
return h
|
559 |
+
|
560 |
+
|
561 |
+
class MHB(nn.Module):
|
562 |
+
"""Multiple Hybrid Blocks
|
563 |
+
Args:
|
564 |
+
hidden_dim: feature dimension
|
565 |
+
num_attention_heads: number of attention heads to use
|
566 |
+
dim_expansion: multiplicative factor on the hidden dimension for the global context representation
|
567 |
+
num_layers: number of layers to use in the MLP for context encoding
|
568 |
+
num_blocks: number of successive blocks to use in the module.
|
569 |
+
is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
|
570 |
+
"""
|
571 |
+
|
572 |
+
def __init__(
|
573 |
+
self,
|
574 |
+
hidden_dim: int,
|
575 |
+
num_attention_heads: int,
|
576 |
+
dim_expansion: int,
|
577 |
+
num_layers: int,
|
578 |
+
num_blocks: int,
|
579 |
+
is_mlp_residual: bool,
|
580 |
+
):
|
581 |
+
super().__init__()
|
582 |
+
self.initial_global_context = nn.parameter.Parameter(
|
583 |
+
torch.ones(1, hidden_dim * dim_expansion)
|
584 |
+
)
|
585 |
+
list_hb = []
|
586 |
+
for i in range(num_blocks):
|
587 |
+
list_hb.append(
|
588 |
+
HybridBlock(
|
589 |
+
hidden_dim,
|
590 |
+
num_attention_heads,
|
591 |
+
dim_expansion,
|
592 |
+
num_layers,
|
593 |
+
is_mlp_residual,
|
594 |
+
)
|
595 |
+
)
|
596 |
+
self.hybrid_blocks = nn.ModuleList(list_hb)
|
597 |
+
|
598 |
+
def forward(
|
599 |
+
self,
|
600 |
+
encoded_agents: torch.Tensor,
|
601 |
+
mask_agents: torch.Tensor,
|
602 |
+
encoded_absolute_agents: torch.Tensor,
|
603 |
+
encoded_map: torch.Tensor,
|
604 |
+
mask_map: torch.Tensor,
|
605 |
+
) -> torch.Tensor:
|
606 |
+
"""Forward function of the block, returning only the output (no attention matrix nor context)
|
607 |
+
Args:
|
608 |
+
encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
|
609 |
+
mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
|
610 |
+
encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
|
611 |
+
encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
|
612 |
+
mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
|
613 |
+
"""
|
614 |
+
sum_h = encoded_agents
|
615 |
+
sum_c = self.initial_global_context
|
616 |
+
h = encoded_agents
|
617 |
+
c = self.initial_global_context
|
618 |
+
for i, hb in enumerate(self.hybrid_blocks):
|
619 |
+
h_new, c_new = hb(
|
620 |
+
h, mask_agents, encoded_absolute_agents, encoded_map, mask_map, c
|
621 |
+
)
|
622 |
+
sum_h = sum_h + h_new
|
623 |
+
sum_c = sum_c + c_new
|
624 |
+
h = (sum_h / (i + 2)).clone()
|
625 |
+
c = (sum_c / (i + 2)).clone()
|
626 |
+
return h
|
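For orientation, here is a minimal sketch of how one of these interaction encoders can be driven end to end. The `MCG` constructor and forward signatures are taken from the definitions above; the dummy dimensions are illustrative assumptions, and the `CG_block`/`ContextGating` internals live elsewhere in this file.

# Illustrative usage sketch, not part of the repository files.
import torch

batch_size, num_agents, num_objects, hidden_dim = 4, 8, 16, 64  # arbitrary sizes

encoder = MCG(
    hidden_dim=hidden_dim,
    dim_expansion=2,
    num_layers=3,
    num_blocks=3,
    is_mlp_residual=True,
)
encoded_agents = torch.randn(batch_size, num_agents, hidden_dim)
mask_agents = torch.ones(batch_size, num_agents, dtype=torch.bool)
encoded_absolute = torch.randn(batch_size, num_agents, hidden_dim)
encoded_map = torch.randn(batch_size, num_objects, hidden_dim)
mask_map = torch.ones(batch_size, num_objects, dtype=torch.bool)

# Each block refines the agent features; the forward loop keeps a running
# mean of the input and all block outputs, so the result has the same shape.
out = encoder(encoded_agents, mask_agents, encoded_absolute, encoded_map, mask_map)
assert out.shape == (batch_size, num_agents, hidden_dim)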
risk_biased/mpc_planner/__init__.py
ADDED
File without changes
risk_biased/mpc_planner/dynamics.py
ADDED
@@ -0,0 +1,49 @@
import torch

from risk_biased.utils.planner_utils import AbstractState, to_state


class PositionVelocityDoubleIntegrator:
    """Deterministic discrete-time double-integrator dynamics, where the state is
    [position_x_m, position_y_m, velocity_x_m_s, velocity_y_m_s] and the control is
    [acceleration_x_m_s2, acceleration_y_m_s2].

    Args:
        dt: time differential between two discrete timesteps in seconds
    """

    def __init__(self, dt: float):
        self.dt = dt
        self.control_dim = 2

    def simulate(
        self,
        state_init: AbstractState,
        control_input: torch.Tensor,
    ) -> AbstractState:
        """Euler-integrate the dynamics from the initial position and the initial velocity
        given an acceleration input

        Args:
            state_init: initial Markov state of the system
            control_input: (num_agents, num_steps_future, 2) tensor of acceleration input

        Returns:
            (num_agents, num_steps_future) AbstractState of the simulated future Markov
            state sequence
        """
        position_init, velocity_init = state_init.position, state_init.velocity

        assert (
            control_input.shape[-1] == self.control_dim
        ), "invalid control input dimension"

        # v_k = v_0 + dt * (cumulative sum of accelerations up to step k)
        velocity_future = velocity_init + self.dt * torch.cumsum(control_input, dim=-2)
        # p_k = p_0 + dt * (cumulative sum of velocities up to step k)
        position_future = position_init + self.dt * torch.cumsum(
            velocity_future, dim=-2
        )
        state_future = to_state(
            torch.cat((position_future, velocity_future), dim=-1), self.dt
        )
        return state_future
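A short sketch of how this integrator behaves, assuming (consistently with how `simulate` itself builds states) that `to_state` accepts a trailing (x, y, vx, vy) layout; the numbers are illustrative:

# Illustrative usage sketch, not part of the repository files.
import torch
from risk_biased.utils.planner_utils import to_state
from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator

dt = 0.1
dynamics = PositionVelocityDoubleIntegrator(dt)
# One agent at the origin moving at 1 m/s along x: trailing dims are (x, y, vx, vy).
state_init = to_state(torch.tensor([[0.0, 0.0, 1.0, 0.0]]), dt)
# Constant 1 m/s^2 acceleration along x for 10 future steps.
control = torch.zeros(1, 10, 2)
control[..., 0] = 1.0
state_future = dynamics.simulate(state_init, control)
# Velocities follow v_k = v_0 + dt * sum(a), positions p_k = p_0 + dt * sum(v).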
risk_biased/mpc_planner/planner.py
ADDED
@@ -0,0 +1,332 @@
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch
from mmcv import Config

from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
from risk_biased.mpc_planner.planner_cost import TrackingCost, TrackingCostParams
from risk_biased.mpc_planner.solver import CrossEntropySolver, CrossEntropySolverParams
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
from risk_biased.utils.planner_utils import AbstractState, to_state
from risk_biased.utils.risk import get_risk_estimator


@dataclass
class MPCPlannerParams:
    """Dataclass for MPC-Planner Parameters

    Args:
        dt: discrete time interval in seconds that is used for planning
        num_steps: number of time steps for which the history of ego's and the other actor's
            trajectories is stored
        num_steps_future: number of time steps into the future for which ego's and the other
            actor's trajectories are considered
        acceleration_std_x_m_s2: acceleration noise standard deviation (m/s^2) in the
            x-direction that is used to initialize the Cross Entropy solver
        acceleration_std_y_m_s2: acceleration noise standard deviation (m/s^2) in the
            y-direction that is used to initialize the Cross Entropy solver
        risk_estimator_params: parameters for the Monte Carlo risk estimator used in the
            planner for ego's control optimization
        solver_params: parameters for the CrossEntropySolver
        tracking_cost_params: parameters for the TrackingCost
        ttc_cost_params: parameters for the TTCCost (i.e., collision cost between ego and
            the other actor)
    """

    dt: float
    num_steps: int
    num_steps_future: int
    acceleration_std_x_m_s2: float
    acceleration_std_y_m_s2: float

    risk_estimator_params: dict
    solver_params: CrossEntropySolverParams
    tracking_cost_params: TrackingCostParams
    ttc_cost_params: TTCCostParams

    @staticmethod
    def from_config(cfg: Config):
        return MPCPlannerParams(
            cfg.dt,
            cfg.num_steps,
            cfg.num_steps_future,
            cfg.acceleration_std_x_m_s2,
            cfg.acceleration_std_y_m_s2,
            cfg.risk_estimator,
            CrossEntropySolverParams.from_config(cfg),
            TrackingCostParams.from_config(cfg),
            TTCCostParams.from_config(cfg),
        )


class MPCPlanner:
    """MPC Planner with a Cross Entropy solver

    Args:
        params: MPCPlannerParams object
        predictor: LitTrajectoryPredictor object
        normalizer: function that takes in an unnormalized trajectory and that outputs the
            normalized trajectory and the offset in this order
    """

    def __init__(
        self,
        params: MPCPlannerParams,
        predictor: LitTrajectoryPredictor,
        normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        self.params = params
        self.dynamics_model = PositionVelocityDoubleIntegrator(params.dt)
        self.control_input_mean_init = torch.zeros(
            1, params.num_steps_future, self.dynamics_model.control_dim
        )
        self.control_input_std_init = torch.Tensor(
            [
                params.acceleration_std_x_m_s2,
                params.acceleration_std_y_m_s2,
            ]
        ).expand_as(self.control_input_mean_init)
        self.solver = CrossEntropySolver(
            params=params.solver_params,
            dynamics_model=self.dynamics_model,
            control_input_mean=self.control_input_mean_init,
            control_input_std=self.control_input_std_init,
            tracking_cost_function=TrackingCost(params.tracking_cost_params),
            interaction_cost_function=TTCCostTorch(params.ttc_cost_params),
            risk_estimator=get_risk_estimator(params.risk_estimator_params),
        )
        self.predictor = predictor
        self.normalizer = normalizer

        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None

        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def replan(
        self,
        current_ado_state: AbstractState,
        current_ego_state: AbstractState,
        target_velocity: torch.Tensor,
        num_prediction_samples: int = 1,
        risk_level: float = 0.0,
        resample_prediction: bool = False,
        risk_in_predictor: bool = False,
    ) -> None:
        """Performs re-planning given the current_ado_state, current_ego_state, and
        target_velocity. Updates ego_state_planned_trajectory. Note that all the information
        given to solver.solve(...) is expressed in the ego-centric frame, whose origin is the
        initial ego position in ego_state_history and whose x-direction is parallel to the
        initial ego velocity.

        Args:
            current_ado_state: ado state
            current_ego_state: ego state
            target_velocity: ((1), 2) tensor
            num_prediction_samples (optional): number of prediction samples. Defaults to 1.
            risk_level (optional): a risk-level float for the entire prediction-planning
                pipeline. If 0.0, risk-neutral prediction and planning are used. Defaults
                to 0.0.
            resample_prediction (optional): If True, prediction is re-sampled in each
                cross-entropy iteration. Defaults to False.
            risk_in_predictor (optional): If True, risk-biased prediction is used and the
                solver becomes risk-neutral. If False, risk-neutral prediction is used and
                the solver becomes risk-sensitive. Defaults to False.
        """
        self._update_ado_state_history(current_ado_state)
        self._update_ego_state_history(current_ego_state)
        self._update_ego_state_target_trajectory(current_ego_state, target_velocity)
        if self.ado_state_history.shape[-1] >= self.params.num_steps:
            self.solver.solve(
                self.predictor,
                self._map_to_ego_centric_frame(self.ego_state_history),
                self._map_to_ego_centric_frame(self._ego_state_target_trajectory),
                self._map_to_ego_centric_frame(self.ado_state_history),
                self.normalizer,
                num_prediction_samples=num_prediction_samples,
                risk_level=risk_level,
                resample_prediction=resample_prediction,
                risk_in_predictor=risk_in_predictor,
            )
            ego_state_planned_trajectory_in_ego_frame = self.dynamics_model.simulate(
                self._map_to_ego_centric_frame(self.ego_state_history[..., -1]),
                self.solver.control_sequence,
            )
            self._ego_state_planned_trajectory = self._map_to_world_frame(
                ego_state_planned_trajectory_in_ego_frame
            )
            latest_ado_position_future_samples_in_ego_frame = (
                self.solver.fetch_latest_prediction()
            )
            if latest_ado_position_future_samples_in_ego_frame is not None:
                self._latest_ado_position_future_samples = self._map_to_world_frame(
                    latest_ado_position_future_samples_in_ego_frame
                )
            else:
                self._latest_ado_position_future_samples = None

    def get_planned_next_ego_state(self) -> AbstractState:
        """Returns the next ego state according to the ego_state_planned_trajectory

        Returns:
            Planned state
        """
        assert (
            self._ego_state_planned_trajectory is not None
        ), "call self.replan(...) first"
        return self._ego_state_planned_trajectory[..., 0]

    def reset(self) -> None:
        """Resets the planner's internal state. This also fully resets the solver's internal
        state, including solver.control_input_mean_init and solver.control_input_std_init."""
        self.solver.control_input_mean_init = (
            self.control_input_mean_init.detach().clone()
        )
        self.solver.control_input_std_init = (
            self.control_input_std_init.detach().clone()
        )
        self.solver.reset()

        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None

        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def fetch_latest_prediction(self) -> Optional[AbstractState]:
        return self._latest_ado_position_future_samples

    @property
    def ego_state_history(self) -> AbstractState:
        """Returns ego_state_history as one concatenated AbstractState sequence

        Returns:
            ego_state_history state sequence
        """
        assert len(self._ego_state_history) > 0
        return to_state(
            torch.stack(
                [ego_state.get_states(4) for ego_state in self._ego_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    @property
    def ado_state_history(self) -> AbstractState:
        """Returns ado_state_history as one concatenated AbstractState sequence

        Returns:
            ado_state_history state sequence
        """
        assert len(self._ado_state_history) > 0
        return to_state(
            torch.stack(
                [ado_state.get_states(4) for ado_state in self._ado_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    def _update_ego_state_history(self, current_ego_state: AbstractState) -> None:
        """Updates ego_state_history with the current_ego_state

        Args:
            current_ego_state: current ego state
        """
        if len(self._ego_state_history) >= self.params.num_steps:
            self._ego_state_history = self._ego_state_history[1:]
        self._ego_state_history.append(current_ego_state)
        assert len(self._ego_state_history) <= self.params.num_steps

    def _update_ado_state_history(self, current_ado_state: AbstractState) -> None:
        """Updates ado_state_history with the current_ado_state

        Args:
            current_ado_state: states of the current non-ego vehicles
        """
        if len(self._ado_state_history) >= self.params.num_steps:
            self._ado_state_history = self._ado_state_history[1:]
        self._ado_state_history.append(current_ado_state)
        assert len(self._ado_state_history) <= self.params.num_steps

    def _update_ego_state_target_trajectory(
        self, current_ego_state: AbstractState, target_velocity: torch.Tensor
    ) -> None:
        """Updates ego_state_target_trajectory based on the current_ego_state and the
        target_velocity

        Args:
            current_ego_state: state
            target_velocity: (1, 2) tensor
        """
        target_displacement = self.params.dt * target_velocity
        target_position_list = [current_ego_state.position]
        for _ in range(self.params.num_steps_future):
            target_position_list.append(target_position_list[-1] + target_displacement)
        target_position_list = target_position_list[1:]
        target_position = torch.cat(target_position_list, dim=-2)
        target_state = to_state(
            torch.cat(
                (target_position, target_velocity.expand_as(target_position)), dim=-1
            ),
            self.params.dt,
        )
        self._ego_state_target_trajectory = target_state

    def _map_to_ego_centric_frame(
        self, trajectory_in_world_frame: AbstractState
    ) -> AbstractState:
        """Maps a trajectory expressed in the world frame to the ego-centric frame, whose
        origin is the initial ego position in ego_state_history and whose x-direction is
        parallel to the initial ego velocity

        Args:
            trajectory_in_world_frame: sequence of states

        Returns:
            trajectory mapped to the ego-centric frame
        """
        ego_pos_init = self.ego_state_history.position[..., -1, :]
        ego_vel_init = self.ego_state_history.velocity[..., -1, :]
        ego_rot_init = torch.atan2(ego_vel_init[..., 1], ego_vel_init[..., 0])
        trajectory_in_ego_frame = trajectory_in_world_frame.translate(
            -ego_pos_init
        ).rotate(-ego_rot_init)
        return trajectory_in_ego_frame

    def _map_to_world_frame(
        self, trajectory_in_ego_frame: AbstractState
    ) -> AbstractState:
        """Maps a trajectory expressed in the ego-centric frame to the world frame

        Args:
            trajectory_in_ego_frame: state trajectory expressed in the ego-centric frame,
                whose origin is the initial ego position in ego_state_history and whose
                x-direction is parallel to the initial ego velocity

        Returns:
            trajectory mapped to the world frame
        """
        # state starts with x, y, angle
        ego_pos_init = self.ego_state_history.position[..., -1, :]
        ego_rot_init = self.ego_state_history.angle[..., -1, :]
        trajectory_in_world_frame = trajectory_in_ego_frame.rotate(
            ego_rot_init
        ).translate(ego_pos_init)
        return trajectory_in_world_frame
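As a sanity check on the frame convention used by `_map_to_ego_centric_frame` (translate by the negated ego position, then rotate by the negated heading), here is a minimal sketch with plain 2-D points. The rotation matrix stands in for the `AbstractState.rotate`/`translate` methods defined in `risk_biased/utils/planner_utils.py`; all values are illustrative.

# Illustrative sketch, not part of the repository files.
import torch

def world_to_ego(points: torch.Tensor, ego_pos: torch.Tensor, ego_heading: float) -> torch.Tensor:
    """Translate by -ego_pos, then rotate by -ego_heading (same order as the planner)."""
    c = torch.cos(torch.tensor(-ego_heading))
    s = torch.sin(torch.tensor(-ego_heading))
    rot = torch.stack((torch.stack((c, -s)), torch.stack((s, c))))
    return (points - ego_pos) @ rot.T

# Ego at (2, 1) heading 90 degrees; a point 1 m ahead of it in the world frame.
ego_pos = torch.tensor([2.0, 1.0])
point = torch.tensor([[2.0, 2.0]])
local = world_to_ego(point, ego_pos, torch.pi / 2)
# In the ego frame the point lies on the +x axis: approximately (1, 0).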
risk_biased/mpc_planner/planner_cost.py
ADDED
@@ -0,0 +1,127 @@
from dataclasses import dataclass
from typing import Optional

import torch
from mmcv import Config


@dataclass
class TrackingCostParams:
    scale_longitudinal: float
    scale_lateral: float
    reduce: str

    @staticmethod
    def from_config(cfg: Config):
        return TrackingCostParams(
            scale_longitudinal=cfg.tracking_cost_scale_longitudinal,
            scale_lateral=cfg.tracking_cost_scale_lateral,
            reduce=cfg.tracking_cost_reduce,
        )


class TrackingCost:
    """Quadratic Trajectory Tracking Cost

    Args:
        params: tracking cost parameters
    """

    def __init__(self, params: TrackingCostParams) -> None:
        self.scale_longitudinal = params.scale_longitudinal
        self.scale_lateral = params.scale_lateral
        assert params.reduce in [
            "min",
            "max",
            "mean",
            "now",
            "final",
        ], "unsupported reduce type"
        self._reduce_fun_name = params.reduce

    def __call__(
        self,
        ego_position_trajectory: torch.Tensor,
        target_position_trajectory: torch.Tensor,
        target_velocity_trajectory: torch.Tensor,
    ) -> torch.Tensor:
        """Computes the quadratic tracking cost

        Args:
            ego_position_trajectory: (some_shape, num_some_steps, 2) tensor of ego
                position trajectory
            target_position_trajectory: (some_shape, num_some_steps, 2) tensor of
                ego target position trajectory
            target_velocity_trajectory: (some_shape, num_some_steps, 2) tensor of
                ego target velocity trajectory

        Returns:
            (some_shape) cost
        """
        cost_matrix = self._get_quadratic_cost_matrix(target_velocity_trajectory)
        position_error = ego_position_trajectory - target_position_trajectory
        cost = (
            (position_error.unsqueeze(-2) @ cost_matrix @ position_error.unsqueeze(-1))
            .squeeze(-1)
            .squeeze(-1)
        )
        return self._reduce(cost, dim=-1)

    def _reduce(self, cost: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
        """Reduces the cost tensor based on self._reduce_fun_name

        Args:
            cost: cost tensor of some shape where the last dimension represents time
            dim (optional): tensor dimension to be reduced. Defaults to None.

        Returns:
            reduced cost tensor
        """
        if self._reduce_fun_name == "min":
            return torch.min(cost, dim=dim)[0] if dim is not None else torch.min(cost)
        if self._reduce_fun_name == "max":
            return torch.max(cost, dim=dim)[0] if dim is not None else torch.max(cost)
        if self._reduce_fun_name == "mean":
            return torch.mean(cost, dim=dim) if dim is not None else torch.mean(cost)
        if self._reduce_fun_name == "now":
            return cost[..., 0]
        if self._reduce_fun_name == "final":
            return cost[..., -1]
        # Unreachable: the reduce type is validated in __init__.

    def _get_quadratic_cost_matrix(
        self, target_velocity_trajectory: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """Gets the quadratic cost matrix based on the target velocity direction per time
        step. If the target velocity has zero norm, an all-zero matrix is returned for that
        time step.

        Args:
            target_velocity_trajectory: (some_shape, num_some_steps, 2) tensor of
                ego target velocity trajectory
            eps (optional): small positive number to ensure numerical stability. Defaults
                to 1e-8.

        Returns:
            (some_shape, num_some_steps, 2, 2) quadratic cost matrix
        """
        longitudinal_direction = (
            target_velocity_trajectory
            / (
                torch.linalg.norm(target_velocity_trajectory, dim=-1).unsqueeze(-1)
                + eps
            )
        ).unsqueeze(-1)
        rotation_90_deg = torch.Tensor([[[0.0, -1.0], [1.0, 0.0]]])
        lateral_direction = rotation_90_deg @ longitudinal_direction
        orthogonal_matrix = torch.cat(
            (longitudinal_direction, lateral_direction), dim=-1
        )
        eigen_matrix = torch.Tensor(
            [[[self.scale_longitudinal, 0.0], [0.0, self.scale_lateral]]]
        )
        cost_matrix = (
            orthogonal_matrix @ eigen_matrix @ orthogonal_matrix.transpose(-1, -2)
        )
        return cost_matrix
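In other words, `_get_quadratic_cost_matrix` builds C = R · diag(scale_longitudinal, scale_lateral) · Rᵀ, where the columns of R are the unit longitudinal and lateral directions of the target velocity, so the quadratic form eᵀ C e weighs along-track and cross-track position errors separately. A minimal numeric sketch (values are illustrative):

# Illustrative sketch, not part of the repository files.
import torch

scale_lon, scale_lat = 1.0, 4.0
direction = torch.tensor([[1.0], [0.0]])            # unit longitudinal direction (+x)
lateral = torch.tensor([[0.0, -1.0], [1.0, 0.0]]) @ direction
R = torch.cat((direction, lateral), dim=-1)         # columns: longitudinal, lateral
C = R @ torch.diag(torch.tensor([scale_lon, scale_lat])) @ R.T
error = torch.tensor([[0.5], [0.2]])                # 0.5 m ahead, 0.2 m to the side
cost = (error.T @ C @ error).item()
# cost == scale_lon * 0.5**2 + scale_lat * 0.2**2 == 0.41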
risk_biased/mpc_planner/solver.py
ADDED
@@ -0,0 +1,429 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

from mmcv import Config
import numpy as np
import torch

from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
from risk_biased.mpc_planner.planner_cost import TrackingCost
from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
from risk_biased.utils.cost import BaseCostTorch
from risk_biased.utils.planner_utils import (
    AbstractState,
    to_state,
    evaluate_risk,
    get_interaction_cost,
)
from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator


@dataclass
class CrossEntropySolverParams:
    """Dataclass for Cross Entropy Solver Parameters

    Args:
        num_control_samples: number of Monte Carlo samples for control input
        num_elite: number of elite samples
        iter_max: maximum iteration number
        smoothing_factor: smoothing factor in (0, 1) used to update the mean and the std of
            the control input distribution for the next iteration. If 0, the updated
            distribution is independent of the previous iteration. If 1, the updated
            distribution is the same as the previous iteration.
        mean_warm_start: internally saves control_input_mean of the last iteration of the
            current solve, so that control_input_mean will be warm-started in the next solve
        dt: discrete time interval in seconds
    """

    num_control_samples: int
    num_elite: int
    iter_max: int
    smoothing_factor: float
    mean_warm_start: bool
    dt: float

    @staticmethod
    def from_config(cfg: Config):
        return CrossEntropySolverParams(
            cfg.num_control_samples,
            cfg.num_elite,
            cfg.iter_max,
            cfg.smoothing_factor,
            cfg.mean_warm_start,
            cfg.dt,
        )


class CrossEntropySolver:
    """Cross Entropy Solver for MPC Planner

    Args:
        params: CrossEntropySolverParams object
        dynamics_model: dynamics model for control
        control_input_mean: (num_agents, num_steps_future, control_dim) tensor of control
            input mean
        control_input_std: (num_agents, num_steps_future, control_dim) tensor of control
            input std
        tracking_cost_function: deterministic tracking cost that does not involve ado
        interaction_cost_function: interaction cost function between ego and (stochastic) ado
        risk_estimator (optional): Monte Carlo risk estimator for risk computation. If None,
            risk-neutral expectation is used for selection of elites. Defaults to None.
    """

    def __init__(
        self,
        params: CrossEntropySolverParams,
        dynamics_model: PositionVelocityDoubleIntegrator,
        control_input_mean: torch.Tensor,
        control_input_std: torch.Tensor,
        tracking_cost_function: TrackingCost,
        interaction_cost_function: BaseCostTorch,
        risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
    ) -> None:
        self.params = params

        self.control_input_mean_init = control_input_mean.detach().clone()
        self.control_input_std_init = control_input_std.detach().clone()
        assert (
            self.control_input_mean_init.shape == self.control_input_std_init.shape
        ), "control input mean and std must have the same size"
        assert (
            self.control_input_mean_init.shape[-1] == dynamics_model.control_dim
        ), f"control dimension must be {dynamics_model.control_dim}"

        self.dynamics_model = dynamics_model
        self.tracking_cost = tracking_cost_function
        self.interaction_cost = interaction_cost_function
        self.risk_estimator = risk_estimator

        self._iter_current = None
        self._control_input_mean = None
        self._control_input_std = None

        self._latest_ado_position_future_samples = None

        self.reset()

    def reset(self) -> None:
        """Resets the solver's internal state"""
        self._iter_current = 0
        self._control_input_mean = self.control_input_mean_init.clone()
        self._control_input_std = self.control_input_std_init.clone()
        self._latest_ado_position_future_samples = None

    def step(
        self,
        ego_state_history: AbstractState,
        ego_state_target_trajectory: AbstractState,
        ado_state_future_samples: AbstractState,
        weights: torch.Tensor,
        verbose: bool = False,
        risk_level: float = 0.0,
    ) -> Dict:
        """Performs one iteration step of the Cross Entropy Method

        Args:
            ego_state_history: (num_agents, num_steps) ego state history
            ego_state_target_trajectory: (num_agents, num_steps_future) ego target
                state trajectory
            ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
                predicted ado trajectory samples
            weights: (num_prediction_samples, num_agents) prediction sample weights
            verbose (optional): print progress. Defaults to False.
            risk_level (optional): a risk-level float for the solver. If 0.0, risk-neutral
                expectation is used for selection of elites. Defaults to 0.0.

        Returns:
            Dictionary containing information about this solver step.
        """
        self._iter_current += 1
        ego_control_input = torch.normal(
            self._control_input_mean.expand(
                self.params.num_control_samples, -1, -1, -1
            ),
            self._control_input_std.expand(self.params.num_control_samples, -1, -1, -1),
        )
        if verbose:
            print(f"**Cross Entropy Iteration {self._iter_current}")
            print(
                f"****Drawing ego's control input samples of {ego_control_input.shape}"
            )
        ego_state_current = ego_state_history[..., -1]
        ego_state_future = self.dynamics_model.simulate(
            ego_state_current, ego_control_input
        )
        if verbose:
            print("****Simulating ego's future state trajectory")

        # state starts with x, y, angle, vx, vy
        tracking_cost = self.tracking_cost(
            ego_state_future.position,
            ego_state_target_trajectory.position,
            ego_state_target_trajectory.velocity,
        )
        if verbose:
            print(
                f"****Computing tracking cost of {tracking_cost.shape} for the control input samples"
            )

        # state starts with x, y
        interaction_cost = get_interaction_cost(
            ego_state_future,
            ado_state_future_samples,
            self.interaction_cost,
        )
        if verbose:
            print(
                f"****Computing interaction cost of {interaction_cost.shape} for the control input samples"
            )
        interaction_risk = evaluate_risk(
            risk_level,
            interaction_cost,
            weights.permute(1, 0).unsqueeze(0).expand_as(interaction_cost),
            self.risk_estimator,
        )

        total_risk = interaction_risk + tracking_cost
        elite_ego_control_input, elite_total_risk = self._get_elites(
            ego_control_input, total_risk
        )
        if verbose:
            print(f"****Selecting {self.params.num_elite} elite samples")
            print(f"****Elite Total_Risk Information: {elite_total_risk}")

        info = dict(
            iteration=self._iter_current,
            control_input_mean=self._control_input_mean.detach().cpu().numpy().copy(),
            control_input_std=self._control_input_std.detach().cpu().numpy().copy(),
            ego_state_future=ego_state_future.get_states(5)
            .detach()
            .cpu()
            .numpy()
            .copy(),
            ado_state_future_samples=ado_state_future_samples.get_states(5)
            .detach()
            .cpu()
            .numpy()
            .copy(),
            sample_weights=weights.detach().cpu().numpy().copy(),
            tracking_cost=tracking_cost.detach().cpu().numpy().copy(),
            interaction_cost=interaction_cost.detach().cpu().numpy().copy(),
            total_risk=total_risk.detach().cpu().numpy().copy(),
        )

        self._update_control_distribution(elite_ego_control_input)
        if verbose:
            print("****Updating ego's control distribution")

        return info

    def solve(
        self,
        predictor: LitTrajectoryPredictor,
        ego_state_history: AbstractState,
        ego_state_target_trajectory: AbstractState,
        ado_state_history: AbstractState,
        normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
        num_prediction_samples: int = 1,
        verbose: bool = False,
        risk_level: float = 0.0,
        resample_prediction: bool = False,
        risk_in_predictor: bool = False,
    ) -> List[Dict]:
        """Performs Cross Entropy optimization of ego's control input

        Args:
            predictor: LitTrajectoryPredictor object
            ego_state_history: (num_agents, num_steps, state_dim) ego state history
            ego_state_target_trajectory: (num_agents, num_steps_future, state_dim) ego
                target state trajectory
            ado_state_history: (num_agents, num_steps, state_dim) ado state history
            normalizer: function that takes in an unnormalized trajectory and that outputs
                the normalized trajectory and the offset in this order
            num_prediction_samples: number of prediction samples. Defaults to 1.
            verbose (optional): print progress. Defaults to False.
            risk_level (optional): a risk-level float for the entire prediction-planning
                pipeline. If 0.0, risk-neutral prediction and planning are used. Defaults
                to 0.0.
            resample_prediction (optional): If True, prediction is re-sampled in each
                cross-entropy iteration. Defaults to False.
            risk_in_predictor (optional): If True, risk-biased prediction is used and the
                solver becomes risk-neutral. If False, risk-neutral prediction is used and
                the solver becomes risk-sensitive. Defaults to False.

        Returns:
            List of dictionaries each containing information about the corresponding solver
            step.
        """
        if risk_level == 0.0:
            risk_level_planner, risk_level_predictor = 0.0, 0.0
        elif risk_in_predictor:
            risk_level_planner, risk_level_predictor = 0.0, risk_level
        else:
            risk_level_planner, risk_level_predictor = risk_level, 0.0
        self.reset()
        infos = []
        ego_state_future = self.dynamics_model.simulate(
            ego_state_history[..., -1],
            self.control_sequence,
        )
        for iteration in range(self.params.iter_max):
            assert iteration == self._iter_current
            if resample_prediction or self._iter_current == 0:
                ado_state_future_samples, weights = self.sample_prediction(
                    predictor,
                    ado_state_history,
                    normalizer,
                    ego_state_history,
                    ego_state_future,
                    num_prediction_samples,
                    risk_level_predictor,
                )
                self._latest_ado_position_future_samples = ado_state_future_samples
            info = self.step(
                ego_state_history,
                ego_state_target_trajectory,
                ado_state_future_samples,
                weights,
                verbose=verbose,
                risk_level=risk_level_planner,
            )
            infos.append(info)
        if self.params.mean_warm_start:
            self.control_input_mean_init[:, :-1] = (
                self._control_input_mean[:, 1:].detach().clone()
            )
        return infos

    @property
    def control_sequence(self) -> torch.Tensor:
        """Returns the planned control sequence, which is a detached copy of the control
        input mean tensor

        Returns:
            (num_steps_future, control_dim) control sequence tensor
        """
        return self._control_input_mean.detach().clone()

    @staticmethod
    def sample_prediction(
        predictor: LitTrajectoryPredictor,
        ado_state_history: AbstractState,
        normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
        ego_state_history: AbstractState,
        ego_state_future: AbstractState,
        num_prediction_samples: int = 1,
        risk_level: float = 0.0,
    ) -> Tuple[AbstractState, torch.Tensor]:
        """Samples a prediction from the predictor given the history, the normalizer, and
        the desired risk-level

        Args:
            predictor: LitTrajectoryPredictor object
            ado_state_history: (num_agents, num_steps) ado state history
            normalizer: function that takes in an unnormalized trajectory and that outputs
                the normalized trajectory and the offset in this order
            ego_state_history: (num_agents, num_steps) ego state history
            ego_state_future: (num_agents, num_steps_future) ego state future
            num_prediction_samples (optional): number of prediction samples. Defaults to 1.
            risk_level (optional): a risk-level float for the predictor. If 0.0, risk-neutral
                prediction is sampled. Defaults to 0.0.

        Returns:
            state samples of shape (num_prediction_samples, num_agents, num_steps_future)
            probability weights of the samples of shape (num_prediction_samples, num_agents)
        """
        ado_position_history_normalized, offset = normalizer(
            ado_state_history.get_states(predictor.dynamic_state_dim)
            .unsqueeze(0)
            .expand(num_prediction_samples, -1, -1, -1)
        )

        x = ado_position_history_normalized.clone()
        mask_x = torch.ones_like(x[..., 0])
        map = torch.empty(num_prediction_samples, 0, 0, 2, device=x.device)
        mask_map = torch.empty(num_prediction_samples, 0, 0, device=x.device)

        batch = (
            x,
            mask_x,
            map,
            mask_map,
            offset,
            ego_state_history.get_states(predictor.dynamic_state_dim)
            .unsqueeze(0)
            .expand(num_prediction_samples, -1, -1, -1),
            ego_state_future.get_states(predictor.dynamic_state_dim)
            .unsqueeze(0)
            .expand(num_prediction_samples, -1, -1, -1),
        )

        ado_position_future_samples, weights = predictor.predict_step(
            batch,
            0,
            risk_level=risk_level,
            return_weights=True,
        )
        ado_position_future_samples = ado_position_future_samples.detach().cpu()
        weights = weights.detach().cpu()

        return to_state(ado_position_future_samples, predictor.dt), weights

    def fetch_latest_prediction(self) -> Optional[AbstractState]:
        return self._latest_ado_position_future_samples

    def _get_elites(
        self, control_input: torch.Tensor, risk: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Selects elite control inputs based on the corresponding risk (lower is better)

        Args:
            control_input: (num_control_samples, num_agents, num_steps_future, control_dim)
                control samples
            risk: (num_control_samples, num_agents) risk tensor

        Returns:
            elite_control_input: (num_elite, num_agents, num_steps_future, control_dim)
                elite control
            elite_risk: (num_elite, num_agents) elite risk
        """
        num_control_samples = self.params.num_control_samples
        assert (
            control_input.shape[0] == num_control_samples
        ), f"size of control_input tensor must be {num_control_samples} at dimension 0"
        assert (
            risk.shape[0] == num_control_samples
        ), f"size of risk tensor must be {num_control_samples} at dimension 0"

        _, sorted_risk_indices = torch.sort(risk, dim=0)
        elite_control_input = control_input[
            sorted_risk_indices[: self.params.num_elite], np.arange(risk.shape[1])
        ]
        elite_risk = risk[
            sorted_risk_indices[: self.params.num_elite], np.arange(risk.shape[1])
        ]
        return elite_control_input, elite_risk

    def _update_control_distribution(self, elite_control_input: torch.Tensor) -> None:
        """Updates the control input distribution using the elites

        Args:
            elite_control_input: (num_elite, num_agents, num_steps_future, control_dim)
                elite control
        """
        num_elite, smoothing_factor = (
            self.params.num_elite,
            self.params.smoothing_factor,
        )
        assert (
            elite_control_input.shape[0] == num_elite
        ), f"size of elite_control_input tensor must be {num_elite} at dimension 0"

        elite_control_input_mean = elite_control_input.mean(dim=0, keepdim=False)
        if num_elite < 2:
            elite_control_input_std = torch.zeros_like(elite_control_input_mean)
        else:
            elite_control_input_std = elite_control_input.std(dim=0, keepdim=False)
        self._control_input_mean = (
            1.0 - smoothing_factor
        ) * elite_control_input_mean + smoothing_factor * self._control_input_mean
        self._control_input_std = (
            1.0 - smoothing_factor
        ) * elite_control_input_std + smoothing_factor * self._control_input_std
|
risk_biased/predictors/biased_predictor.py
ADDED
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Optional, Tuple, Union


from einops import repeat
from mmcv import Config
import pytorch_lightning as pl
import torch

from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.biased_cvae_model import (
    cvae_factory,
)

from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
from risk_biased.utils.risk import get_risk_estimator
from risk_biased.utils.risk import get_risk_level_sampler


@dataclass
class LitTrajectoryPredictorParams:
    """
    cvae_params: CVAEParams class defining the necessary parameters for the CVAE model
    risk_distribution: dict of strings and values defining the risk distribution to use
    risk_estimator: dict of strings and values defining the risk estimator to use
    kl_weight: float defining the weight of the KL term in the loss function
    kl_threshold: float defining the threshold to apply when computing the KL divergence (avoids posterior collapse)
    risk_weight: float defining the weight of the risk term in the loss function
    n_mc_samples_risk: int defining the number of Monte Carlo samples to use when estimating the risk
    n_mc_samples_biased: int defining the number of Monte Carlo samples to use when estimating the expected biased cost
    dt: float defining the duration between two consecutive time steps
    learning_rate: float defining the learning rate for the optimizer
    use_risk_constraint: bool defining whether to use the risk-constrained optimization procedure
    risk_constraint_update_every_n_epoch: int defining the number of epochs between two risk weight updates
    risk_constraint_weight_update_factor: float defining the factor by which the risk weight is multiplied at each update
    risk_constraint_weight_maximum: float defining the maximum value of the risk weight
    num_samples_min_fde: int defining the number of samples to use when estimating the minimum FDE
    condition_on_ego_future: bool defining whether to condition the biasing on the ego future trajectory (else on the ego past)
    """

    cvae_params: CVAEParams
    risk_distribution: dict
    risk_estimator: dict
    kl_weight: float
    kl_threshold: float
    risk_weight: float
    n_mc_samples_risk: int
    n_mc_samples_biased: int
    dt: float
    learning_rate: float
    use_risk_constraint: bool
    risk_constraint_update_every_n_epoch: int
    risk_constraint_weight_update_factor: float
    risk_constraint_weight_maximum: float
    num_samples_min_fde: int
    condition_on_ego_future: bool

    @staticmethod
    def from_config(cfg: Config):
        cvae_params = CVAEParams.from_config(cfg)
        return LitTrajectoryPredictorParams(
            risk_distribution=cfg.risk_distribution,
            risk_estimator=cfg.risk_estimator,
            kl_weight=cfg.kl_weight,
            kl_threshold=cfg.kl_threshold,
            risk_weight=cfg.risk_weight,
            n_mc_samples_risk=cfg.n_mc_samples_risk,
            n_mc_samples_biased=cfg.n_mc_samples_biased,
            dt=cfg.dt,
            learning_rate=cfg.learning_rate,
            cvae_params=cvae_params,
            use_risk_constraint=cfg.use_risk_constraint,
            risk_constraint_update_every_n_epoch=cfg.risk_constraint_update_every_n_epoch,
            risk_constraint_weight_update_factor=cfg.risk_constraint_weight_update_factor,
            risk_constraint_weight_maximum=cfg.risk_constraint_weight_maximum,
            num_samples_min_fde=cfg.num_samples_min_fde,
            condition_on_ego_future=cfg.condition_on_ego_future,
        )


class LitTrajectoryPredictor(pl.LightningModule):
    """PyTorch Lightning module for trajectory prediction with the biased CVAE model

    Args:
        params : dataclass object containing the necessary parameters
        cost_params: dataclass object defining the TTC cost function
        unnormalizer: function that takes in a trajectory and an offset and outputs the
            unnormalized trajectory
    """

    def __init__(
        self,
        params: LitTrajectoryPredictorParams,
        cost_params: TTCCostParams,
        unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    ) -> None:
        super().__init__()
        model = cvae_factory(
            params.cvae_params,
            cost_function=TTCCostTorch(cost_params),
            risk_estimator=get_risk_estimator(params.risk_estimator),
            training_mode="cvae",
        )
        self.model = model
        self.params = params
        self._unnormalize_trajectory = unnormalizer
        self.set_training_mode("cvae")

        self.learning_rate = params.learning_rate
        self.num_samples_min_fde = params.num_samples_min_fde

        self.dynamic_state_dim = params.cvae_params.dynamic_state_dim
        self.dt = params.cvae_params.dt

        self.use_risk_constraint = params.use_risk_constraint
        self.risk_weight = params.risk_weight
        self.risk_weight_ratio = params.risk_weight / params.kl_weight
        self.kl_weight = params.kl_weight
        if self.use_risk_constraint:
            self.risk_constraint_update_every_n_epoch = (
                params.risk_constraint_update_every_n_epoch
            )
            self.risk_constraint_weight_update_factor = (
                params.risk_constraint_weight_update_factor
            )
            self.risk_constraint_weight_maximum = params.risk_constraint_weight_maximum

        self._risk_sampler = get_risk_level_sampler(params.risk_distribution)

    def set_training_mode(self, training_mode: str):
        self.model.set_training_mode(training_mode)
        self.partial_get_loss = partial(
            self.model.get_loss,
            kl_threshold=self.params.kl_threshold,
            n_samples_risk=self.params.n_mc_samples_risk,
            n_samples_biased=self.params.n_mc_samples_biased,
            dt=self.params.dt,
            unnormalizer=self._unnormalize_trajectory,
        )

    def _get_loss(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        y: torch.Tensor,
        mask_y: torch.Tensor,
        mask_loss: torch.Tensor,
        x_ego: torch.Tensor,
        y_ego: torch.Tensor,
        offset: Optional[torch.Tensor] = None,
        risk_level: Optional[torch.Tensor] = None,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]], dict]:
        """Compute the loss based on the trajectory history x and future y

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor, True where map features are valid and False where they are padding
            y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
            mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask.
            mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool mask set to True where the loss
                should be computed and to False where it shouldn't
            x_ego: (batch_size, 1, num_steps, state_dim) tensor of ego past trajectory
            y_ego: (batch_size, 1, num_steps_future, state_dim) tensor of ego future trajectory
            offset : (batch_size, num_agents, state_dim) offset position from ego
            risk_level : (batch_size, num_agents) tensor of risk levels desired for future trajectories

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, ...]]: (1,) loss tensor or tuple of
                loss tensors
            dict: dict that contains values to be logged
        """
        return self.partial_get_loss(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_y=mask_y,
            mask_loss=mask_loss,
            offset=offset,
            risk_level=risk_level,
            x_ego=x_ego,
            y_ego=y_ego,
            risk_weight=self.risk_weight,
            kl_weight=self.kl_weight,
        )

    def log_with_prefix(
        self,
        log_dict: dict,
        prefix: Optional[str] = None,
        on_step: Optional[bool] = None,
        on_epoch: Optional[bool] = None,
    ) -> None:
        """Log the entries of log_dict while optionally adding "<prefix>/" to its keys

        Args:
            log_dict: dict that contains values to be logged
            prefix: prefix to be added to keys
            on_step: if True, logs at this step. None auto-logs at the training_step but not
                at the validation/test_step
            on_epoch: if True, logs epoch-accumulated metrics. None auto-logs at the val/test
                step but not at the training_step
        """
        if prefix is None:
            prefix = ""
        else:
            prefix += "/"

        for (metric, value) in log_dict.items():
            metric = prefix + metric
            self.log(metric, value, on_step=on_step, on_epoch=on_epoch)

    def configure_optimizers(
        self,
    ) -> Union[torch.optim.Optimizer, List[torch.optim.Optimizer]]:
        """Configure the optimizer for PyTorch Lightning

        Returns:
            torch.optim.Optimizer: optimizer (or list of optimizers) to be used for training
        """
        if isinstance(self.model.get_parameters(), list):
            self._optimizers = [
                torch.optim.Adam(params, lr=self.learning_rate)
                for params in self.model.get_parameters()
            ]
        else:
            self._optimizers = [
                torch.optim.Adam(self.model.get_parameters(), lr=self.learning_rate)
            ]
        return self._optimizers

    def training_step(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int,
    ) -> dict:
        """Training step definition for PyTorch Lightning

        Args:
            batch : [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
                (batch_size, num_agents, num_steps), # mask past, False where past trajectories are padding data
                (batch_size, num_agents, num_steps_future, state_dim), # future trajectory
                (batch_size, num_agents, num_steps_future), # mask future, False where future trajectories are padding data
                (batch_size, num_agents, num_steps_future), # mask loss, False where future trajectories are not to be predicted
                (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
                (batch_size, num_objects, object_seq_len), # mask map, False where map objects are padding data
                (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
                (batch_size, 1, num_steps, state_dim), # ego past trajectory
                (batch_size, 1, num_steps_future, state_dim)] # ego future trajectory
            batch_idx : batch_idx to be used by PyTorch Lightning

        Returns:
            dict: dict of outputs containing the loss
        """
        x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
        risk_level = repeat(
            self._risk_sampler.sample(x.shape[0], x.device),
            "b -> b num_agents",
            num_agents=x.shape[1],
        )
        loss, log_dict = self._get_loss(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_y=mask_y,
            mask_loss=mask_loss,
            offset=offset,
            risk_level=risk_level,
            x_ego=x_ego,
            y_ego=y_ego,
        )
        if isinstance(loss, tuple):
            loss = sum(loss)
        self.log_with_prefix(log_dict, prefix="train", on_step=True, on_epoch=False)

        return {"loss": loss}

    def training_epoch_end(self, outputs: List[dict]) -> None:
        """Called at the end of the training epoch with the outputs of all training steps

        Args:
            outputs: list of outputs of all training steps in the current epoch
        """
        if self.use_risk_constraint:
            if (
                self.model.training_mode == "bias"
                and (self.trainer.current_epoch + 1)
                % self.risk_constraint_update_every_n_epoch
                == 0
            ):
                self.risk_weight_ratio *= self.risk_constraint_weight_update_factor
                if self.risk_weight_ratio < self.risk_constraint_weight_maximum:
                    sum_weight = self.risk_weight + self.kl_weight
                    self.risk_weight = (
                        sum_weight
                        * self.risk_weight_ratio
                        / (1 + self.risk_weight_ratio)
                    )
                    self.kl_weight = sum_weight / (1 + self.risk_weight_ratio)
                # self.risk_weight *= self.risk_constraint_weight_update_factor
                # if self.risk_weight > self.risk_constraint_weight_maximum:
                #     self.risk_weight = self.risk_constraint_weight_maximum
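
    # Illustrative arithmetic for the weight update above (values assumed):
    # starting from kl_weight = 1.0 and risk_weight = 1.0 (ratio 1.0), an
    # update factor of 1.1 yields ratio 1.1 so, while the ratio stays below
    # risk_constraint_weight_maximum, risk_weight becomes 2.0 * 1.1 / 2.1 ≈ 1.048
    # and kl_weight becomes 2.0 / 2.1 ≈ 0.952: the sum of the two weights is
    # preserved while the balance progressively shifts toward the risk term.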

    def _get_risk_tensor(
        self,
        batch_size: int,
        num_agents: int,
        device: torch.device,
        risk_level: Optional[torch.Tensor] = None,
    ):
        """Reformat the accepted risk_level input formats into a (batch_size, num_agents) tensor.
        If given a tensor, the same tensor is returned, moved to the given device.
        If given a float value, a constant tensor filled with this value is returned.
        If given None, None is returned.

        Args:
            batch_size : desired batch size
            num_agents : desired number of agents
            device : device on which to store the risk tensor
            risk_level : the risk level as a tensor, a float value, or None

        Returns:
            (batch_size, num_agents) tensor of risk levels, or None
        """
        if risk_level is not None:
            if isinstance(risk_level, float):
                risk_level = (
                    torch.ones(batch_size, num_agents, device=device) * risk_level
                )
            else:
                risk_level = risk_level.to(device)
        else:
            risk_level = None

        return risk_level

    def validation_step(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int,
        risk_level: float = 1.0,
    ) -> dict:
        """Validation step definition for PyTorch Lightning

        Args:
            batch : the same 10-tuple as in training_step (past, mask past, future,
                mask future, mask loss, map, mask map, offset, ego past, ego future)
            batch_idx : batch_idx to be used by PyTorch Lightning
            risk_level : optional desired risk level

        Returns:
            dict: dict of outputs containing the loss
        """
        x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch

        risk_level = self._get_risk_tensor(
            x.shape[0], x.shape[1], x.device, risk_level=risk_level
        )
        self.model.eval()
        log_dict_accuracy = self.model.get_prediction_accuracy(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_loss=mask_loss,
            offset=offset,
            x_ego=x_ego,
            y_ego=y_ego,
            unnormalizer=self._unnormalize_trajectory,
            risk_level=risk_level,
            num_samples_min_fde=self.num_samples_min_fde,
        )

        loss, log_dict_loss = self._get_loss(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_y=mask_y,
            mask_loss=mask_loss,
            offset=offset,
            risk_level=risk_level,
            x_ego=x_ego,
            y_ego=y_ego,
        )

        if isinstance(loss, tuple):
            loss = sum(loss)

        self.log_with_prefix(
            dict(log_dict_accuracy, **log_dict_loss),
            prefix="val",
            on_step=False,
            on_epoch=True,
        )
        self.model.train()
        return {"loss": loss}

    def test_step(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int,
        risk_level: Optional[torch.Tensor] = None,
    ) -> dict:
        """Test step definition for PyTorch Lightning

        Args:
            batch : the same 10-tuple as in training_step (past, mask past, future,
                mask future, mask loss, map, mask map, offset, ego past, ego future)
            batch_idx : batch_idx to be used by PyTorch Lightning
            risk_level : optional desired risk level

        Returns:
            dict: dict of outputs containing the loss
        """
        x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
        risk_level = self._get_risk_tensor(
            x.shape[0], x.shape[1], x.device, risk_level=risk_level
        )
        loss, log_dict = self._get_loss(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_y=mask_y,
            mask_loss=mask_loss,
            offset=offset,
            risk_level=risk_level,
            x_ego=x_ego,
            y_ego=y_ego,
        )
        if isinstance(loss, tuple):
            loss = sum(loss)
        self.log_with_prefix(log_dict, prefix="test", on_step=False, on_epoch=True)
        return {"loss": loss}

    def predict_step(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int = 0,
        risk_level: Optional[torch.Tensor] = None,
        n_samples: int = 0,
        return_weights: bool = False,
    ) -> torch.Tensor:
        """Predict step definition for PyTorch Lightning

        Args:
            batch: [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
                (batch_size, num_agents, num_steps), # mask past, False where past trajectories are padding data
                (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
                (batch_size, num_objects, object_seq_len), # mask map, False where map objects are padding data
                (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
                (batch_size, 1, num_steps, state_dim), # past trajectory of the ego agent in the scene
                (batch_size, 1, num_steps_future, state_dim)] # future trajectory of the ego agent in the scene
            batch_idx : batch_idx to be used by PyTorch Lightning (unused here)
            risk_level : optional desired risk level
            n_samples: number of samples to predict per agent.
                With a value of 0, the `n_samples` dim is not included in the output.
            return_weights: if True, the sample weights are also returned

        Returns:
            (batch_size, (n_samples), num_steps_future, state_dim) tensor, or a tuple
            of this tensor and the sample weights when return_weights is True
        """
        x, mask_x, map, mask_map, offset, x_ego, y_ego = batch
        risk_level = self._get_risk_tensor(
            batch_size=x.shape[0],
            num_agents=x.shape[1],
            device=x.device,
            risk_level=risk_level,
        )
        y_sampled, weights, _ = self.model(
            x,
            mask_x,
            map,
            mask_map,
            offset=offset,
            x_ego=x_ego,
            y_ego=y_ego,
            risk_level=risk_level,
            n_samples=n_samples,
        )
        predict_sampled = self._unnormalize_trajectory(y_sampled, offset)
        if return_weights:
            return predict_sampled, weights
        else:
            return predict_sampled

    def predict_loop_once(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int = 0,
        risk_level: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Predict with refinement:
        A first prediction is made as in predict_step; however, instead of being unnormalized
        and returned, it is fed to the encoder that was trained to encode past and ground-truth
        future. Then the decoder is used again, but its latent input sample is biased by the
        encoder instead of being a sample of the prior distribution.
        Finally, as in predict_step, the result is unnormalized and returned.

        Args:
            batch: [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
                (batch_size, num_agents, num_steps), # mask past, False where past trajectories are padding data
                (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
                (batch_size, num_objects, object_seq_len), # mask map, False where map objects are padding data
                (batch_size, num_agents, state_dim)] # position offset of all agents relative to ego at present time
            batch_idx : batch_idx to be used by PyTorch Lightning (unused here). Defaults to 0.
            risk_level : optional desired risk level

        Returns:
            torch.Tensor: (batch_size, num_steps_future, state_dim) tensor
        """
        x, mask_x, map, mask_map, offset = batch
        risk_level = self._get_risk_tensor(
            x.shape[0], x.shape[1], x.device, risk_level=risk_level
        )
        y_sampled, _ = self.model(
            x,
            mask_x,
            map,
            mask_map,
            offset=offset,
            risk_level=risk_level,
        )
        mask_y = repeat(mask_x.any(-1), "b a -> b a f", f=y_sampled.shape[-2])
        y_sampled, _ = self.model(
            x,
            mask_x,
            map,
            mask_map,
            y_sampled,
            mask_y,
            offset=offset,
            risk_level=risk_level,
        )
        predict_sampled = self._unnormalize_trajectory(y_sampled, offset=offset)
        return predict_sampled
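
A minimal usage sketch for the module above (not taken from the repository: the config path, the TTCCostParams.from_config call, and the commented batch format are assumptions based on the signatures and docstrings):

from mmcv import Config

from risk_biased.predictors.biased_predictor import (
    LitTrajectoryPredictor,
    LitTrajectoryPredictorParams,
)
from risk_biased.scene_dataset.loaders import SceneDataLoaders
from risk_biased.utils.cost import TTCCostParams

cfg = Config.fromfile("risk_biased/config/learning_config.py")  # assumed path
params = LitTrajectoryPredictorParams.from_config(cfg)
predictor = LitTrajectoryPredictor(
    params,
    cost_params=TTCCostParams.from_config(cfg),  # assumed constructor
    unnormalizer=SceneDataLoaders.unnormalize_trajectory,
)
# batch: the 7-tuple documented in predict_step (x, mask_x, map, mask_map,
# offset, x_ego, y_ego); a risk_level close to 1 requests risk-averse samples.
# predictions = predictor.predict_step(batch, risk_level=0.95, n_samples=16)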
risk_biased/scene_dataset/__init__.py
ADDED
File without changes
risk_biased/scene_dataset/loaders.py
ADDED
@@ -0,0 +1,252 @@
from typing import Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


class SceneDataLoaders:
    """
    This class loads a scene dataset and pre-processes it (normalization, unnormalization)

    Args:
        state_dim : dimension of the observed state (2 for x,y position observation)
        num_steps : number of observed steps
        num_steps_future : number of steps in the future
        batch_size: set data loaders with this batch size
        data_train: training dataset
        data_val: validation dataset
        data_test: test dataset
        num_workers: number of workers to use for data loading
    """

    def __init__(
        self,
        state_dim: int,
        num_steps: int,
        num_steps_future: int,
        batch_size: int,
        data_train: torch.Tensor,
        data_val: torch.Tensor,
        data_test: torch.Tensor,
        num_workers: int = 0,
    ):
        self._batch_size = batch_size
        self._num_workers = num_workers
        self._state_dim = state_dim
        self._num_steps = num_steps
        self._num_steps_future = num_steps_future

        self._setup_datasets(data_train, data_val, data_test)

    def train_dataloader(self, shuffle=True, drop_last=True) -> DataLoader:
        """Setup and return the training DataLoader

        Returns:
            DataLoader: training DataLoader
        """
        data_size = self._data_train_past.shape[0]
        # This is a didactic data loader that only defines minimalistic inputs.
        # It adds some empty tensors and ones to match the expected format with masks and map information.
        train_loader = DataLoader(
            dataset=TensorDataset(
                self._data_train_past,
                torch.ones_like(self._data_train_past[..., 0]),  # Mask past
                self._data_train_fut,
                torch.ones_like(self._data_train_fut[..., 0]),  # Mask fut
                torch.ones_like(self._data_train_fut[..., 0]),  # Mask loss
                torch.empty(
                    data_size, 1, 0, 0, device=self._data_train_past.device
                ),  # Map
                torch.empty(
                    data_size, 1, 0, device=self._data_train_past.device
                ),  # Mask map
                self._offset_train,
                self._data_train_ego_past,
                self._data_train_ego_fut,
            ),
            batch_size=self._batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=self._num_workers,
        )
        return train_loader

    def val_dataloader(self, shuffle=False, drop_last=False) -> DataLoader:
        """Setup and return the validation DataLoader

        Returns:
            DataLoader: validation DataLoader
        """
        data_size = self._data_val_past.shape[0]
        # This is a didactic data loader that only defines minimalistic inputs.
        # It adds some empty tensors and ones to match the expected format with masks and map information.
        val_loader = DataLoader(
            dataset=TensorDataset(
                self._data_val_past,
                torch.ones_like(self._data_val_past[..., 0]),  # Mask past
                self._data_val_fut,
                torch.ones_like(self._data_val_fut[..., 0]),  # Mask fut
                torch.ones_like(self._data_val_fut[..., 0]),  # Mask loss
                torch.zeros(
                    data_size, 1, 0, 0, device=self._data_val_past.device
                ),  # Map
                torch.ones(
                    data_size, 1, 0, device=self._data_val_past.device
                ),  # Mask map
                self._offset_val,
                self._data_val_ego_past,
                self._data_val_ego_fut,
            ),
            batch_size=self._batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=self._num_workers,
        )
        return val_loader

    def test_dataloader(self) -> DataLoader:
        """Setup and return the test DataLoader

        Returns:
            DataLoader: test DataLoader
        """
        data_size = self._data_test_past.shape[0]
        # This is a didactic data loader that only defines minimalistic inputs.
        # It adds some empty tensors and ones to match the expected format with masks and map information.
        test_loader = DataLoader(
            dataset=TensorDataset(
                self._data_test_past,
                torch.ones_like(self._data_test_past[..., 0]),  # Mask
                torch.zeros(
                    data_size, 0, 1, 0, device=self._data_test_past.device
                ),  # Map
                torch.ones(
                    data_size, 0, 1, device=self._data_test_past.device
                ),  # Mask map
                self._offset_test,
                self._data_test_ego_past,
                self._data_test_ego_fut,
            ),
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=self._num_workers,
        )
        return test_loader

    def _setup_datasets(
        self, data_train: torch.Tensor, data_val: torch.Tensor, data_test: torch.Tensor
    ):
        """Setup datasets: normalize and split into past and future
        Args:
            data_train: training dataset
            data_val: validation dataset
            data_test: test dataset
        """
        data_train, data_train_ego = data_train[0], data_train[1]
        data_val, data_val_ego = data_val[0], data_val[1]
        data_test, data_test_ego = data_test[0], data_test[1]

        data_train, self._offset_train = self.normalize_trajectory(data_train)
        data_val, self._offset_val = self.normalize_trajectory(data_val)
        data_test, self._offset_test = self.normalize_trajectory(data_test)
        # This is a didactic data loader that only defines minimalistic inputs.
        # An extra dimension is added to account for the number of agents in the scene.
        # In this minimal input there is only one, but the model using the data expects any number of agents.
        self._data_train_past, self._data_train_fut = self.split_trajectory(data_train)
        self._data_val_past, self._data_val_fut = self.split_trajectory(data_val)
        self._data_test_past, self._data_test_fut = self.split_trajectory(data_test)

        self._data_train_ego_past, self._data_train_ego_fut = self.split_trajectory(
            data_train_ego
        )
        self._data_val_ego_past, self._data_val_ego_fut = self.split_trajectory(
            data_val_ego
        )
        self._data_test_ego_past, self._data_test_ego_fut = self.split_trajectory(
            data_test_ego
        )

    def split_trajectory(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split the input trajectory into history and future

        Args:
            input : (batch_size, (n_agents), num_steps + num_steps_future, state_dim) tensor of the
                entire trajectory [x, y]

        Returns:
            Tuple of history and future trajectories
        """
        assert (
            input.shape[-2] == self._num_steps + self._num_steps_future
        ), "trajectory length ({}) does not match the expected length".format(
            input.shape[-2]
        )
        assert (
            input.shape[-1] == self._state_dim
        ), "state dimension ({}) does not match the expected dimension".format(
            input.shape[-1]
        )

        input_history, input_future = torch.split(
            input, [self._num_steps, self._num_steps_future], dim=-2
        )
        return input_history, input_future

    @staticmethod
    def normalize_trajectory(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize the input trajectory by subtracting the initial state

        Args:
            input : (some_shape, n_agents, num_steps + num_steps_future, state_dim) tensor of the
                entire trajectory [x, y], or (some_shape, num_steps, state_dim) tensor of history x

        Returns:
            Tuple of (normalized_trajectory, offset), where
            normalized_trajectory has the same dimension as the input and offset is a
            (some_shape, state_dim) tensor corresponding to the initial state
        """
        offset = input[..., 0, :].clone()

        return input - offset.unsqueeze(-2), offset

    @staticmethod
    def unnormalize_trajectory(
        input: torch.Tensor, offset: torch.Tensor
    ) -> torch.Tensor:
        """Unnormalize the trajectory by adding the offset to the input

        Args:
            input : (some_shape, (n_sample), num_steps_future, state_dim) tensor of future
                trajectory y
            offset : (some_shape, 2 or 4 or 5) tensor of offset to add to y

        Returns:
            Unnormalized trajectory of the same size as the input
        """
        offset_dim = offset.shape[-1]
        assert input.shape[-1] >= offset_dim
        input_clone = input.clone()
        if offset.ndim == 2:
            batch_size, _ = offset.shape
            assert input_clone.shape[0] == batch_size

            input_clone[..., :offset_dim] = input_clone[
                ..., :offset_dim
            ] + offset.reshape(
                [batch_size, *([1] * (input_clone.ndim - 2)), offset_dim]
            )
        elif offset.ndim == 3:
            batch_size, num_agents, _ = offset.shape
            assert input_clone.shape[0] == batch_size
            assert input_clone.shape[1] == num_agents

            input_clone[..., :offset_dim] = input_clone[
                ..., :offset_dim
            ] + offset.reshape(
                [batch_size, num_agents, *([1] * (input_clone.ndim - 3)), offset_dim]
            )

        return input_clone
risk_biased/scene_dataset/pedestrian.py
ADDED
@@ -0,0 +1,165 @@
import numpy as np
import torch
from torch import Tensor
from typing import Union


class RandomPedestrians:
    """
    Batched random pedestrians.
    There are two types of pedestrians, slow and fast ones.
    Each pedestrian type walks mainly at its constant favored speed, but at each time step there is a probability that it changes its pace.

    Args:
        batch_size: int number of scenes in the batch
        dt: float time step to use in the trajectory sequence
        fast_speed: float fast walking speed for the random pedestrian in meters/second
        slow_speed: float slow walking speed for the random pedestrian in meters/second
        p_change_pace: float probability that a slow (resp. fast) pedestrian walks at fast_speed (resp. slow_speed) at each time step
        proportion_fast: float proportion of the pedestrians that are mainly walking at fast_speed
        is_torch: bool set to True to produce Tensor batches and to False to produce numpy arrays
    """

    def __init__(
        self,
        batch_size: int,
        dt: float = 0.1,
        fast_speed: float = 2,
        slow_speed: float = 1,
        p_change_pace: float = 0.1,
        proportion_fast: float = 0.5,
        is_torch: bool = False,
    ) -> None:

        self.is_torch = is_torch
        self.fast_speed: float = fast_speed
        self.slow_speed: float = slow_speed
        self.dt: float = dt
        self.p_change_pace: float = p_change_pace
        self.batch_size: int = batch_size

        self.proportion_fast: float = proportion_fast
        if self.is_torch:
            self.is_fast_type: Tensor = torch.from_numpy(
                np.random.binomial(1, self.proportion_fast, [batch_size, 1, 1]).astype(
                    "float32"
                )
            )
            self.is_currently_fast: Tensor = self.is_fast_type.clone()
            self.initial_position: Tensor = torch.zeros([batch_size, 1, 2])
            self.position: Tensor = self.initial_position.clone()
            self._angle: Tensor = (2 * torch.rand(batch_size, 1) - 1) * np.pi
            self.unit_velocity: Tensor = torch.stack(
                (torch.cos(self._angle), torch.sin(self._angle)), -1
            )
        else:
            self.is_fast_type: np.ndarray = np.random.binomial(
                1, self.proportion_fast, [batch_size, 1, 1]
            )
            self.is_currently_fast: np.ndarray = self.is_fast_type.copy()
            self.initial_position: np.ndarray = np.zeros([batch_size, 1, 2])
            self.position: np.ndarray = self.initial_position.copy()
            self._angle: np.ndarray = np.random.uniform(-np.pi, np.pi, (batch_size, 1))
            self.unit_velocity: np.ndarray = np.stack(
                (np.cos(self._angle), np.sin(self._angle)), -1
            )

    @property
    def angle(self):
        return self._angle

    @angle.setter
    def angle(self, angle: Union[np.ndarray, torch.Tensor]):
        assert self.batch_size == angle.shape[0]
        if self.is_torch:
            assert isinstance(angle, torch.Tensor)
            self._angle = angle
            self.unit_velocity = torch.stack(
                (torch.cos(self._angle), torch.sin(self._angle)), -1
            )
        else:
            assert isinstance(angle, np.ndarray)
            self._angle = angle
            self.unit_velocity = np.stack(
                (np.cos(self._angle), np.sin(self._angle)), -1
            )

    def step(self) -> None:
        """
        Move forward one time step: update the speed selection and the current position.
        """
        self.update_speed()
        self.update_position()

    def update_speed(self) -> None:
        """
        Update the speed as a random selection between the favored speed and the other speed with probability self.p_change_pace.
        """
        if self.is_torch:
            do_flip = (
                torch.from_numpy(
                    np.random.binomial(1, self.p_change_pace, self.batch_size).astype(
                        "float32"
                    )
                )
                == 1
            )
            self.is_currently_fast = self.is_fast_type.clone()
        else:
            do_flip = np.random.binomial(1, self.p_change_pace, self.batch_size) == 1
            self.is_currently_fast = self.is_fast_type.copy()
        self.is_currently_fast[do_flip] = 1 - self.is_fast_type[do_flip]

    def update_position(self) -> None:
        """
        Update the position as current position + time_step * speed * (cos(angle), sin(angle))
        """
        self.position += (
            self.dt
            * (
                self.slow_speed
                + (self.fast_speed - self.slow_speed) * self.is_currently_fast
            )
            * self.unit_velocity
        )

    def travel_distance(self) -> Union[np.ndarray, Tensor]:
        """
        Return the travel distance between the initial position and the current position.
        """
        if self.is_torch:
            return torch.sqrt(
                torch.sum(torch.square(self.position - self.initial_position), -1)
            )
        else:
            return np.sqrt(np.sum(np.square(self.position - self.initial_position), -1))

    def get_final_position(self, time: float) -> Union[np.ndarray, Tensor]:
        """
        Return a sample of pedestrian final positions using their speed distribution.
        (This is stochastic; different samples will produce different results.)
        Args:
            time: The final time at which to get the position
        Returns:
            The batch of final positions
        """
        num_steps = int(round(time / self.dt))
        if self.is_torch:
            cumulative_change_state = torch.from_numpy(
                np.random.binomial(
                    num_steps, self.p_change_pace, [self.batch_size, 1, 1]
                ).astype("float32")
            )
        else:
            cumulative_change_state = np.random.binomial(
                num_steps, self.p_change_pace, [self.batch_size, 1, 1]
            )

        num_fast_steps = (
            num_steps - 2 * cumulative_change_state
        ) * self.is_fast_type + cumulative_change_state

        return self.position + self.unit_velocity * self.dt * (
            self.slow_speed * num_steps
            + (self.fast_speed - self.slow_speed) * num_fast_steps
        )
risk_biased/scene_dataset/scene.py
ADDED
@@ -0,0 +1,522 @@
from dataclasses import dataclass
import os
from typing import Union, List, Optional
import warnings
import copy

from mmcv import Config
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset

from risk_biased.scene_dataset.pedestrian import RandomPedestrians
from risk_biased.utils.torch_utils import torch_linspace


@dataclass
class RandomSceneParams:
    """Dataclass that defines all the listed parameters that are necessary for a RandomScene object

    Args:
        batch_size: number of scenes in the batch
        time_scene: time length of the scene in seconds
        sample_times: list of times at which to get the positions
        ego_ref_speed: constant reference speed of the ego vehicle in meters/second
        ego_speed_init_low: lowest initial speed of the ego vehicle in meters/second
        ego_speed_init_high: highest initial speed of the ego vehicle in meters/second
        ego_acceleration_mean_low: lowest mean acceleration of the ego vehicle in m/s^2
        ego_acceleration_mean_high: highest mean acceleration of the ego vehicle in m/s^2
        ego_acceleration_std: std of the acceleration of the ego vehicle in m/s^2
        ego_length: length of the ego vehicle in meters
        ego_width: width of the ego vehicle in meters
        dt: time step to use in the trajectory sequence
        fast_speed: fast walking speed for the random pedestrian in meters/second
        slow_speed: slow walking speed for the random pedestrian in meters/second
        p_change_pace: probability that a slow (resp. fast) pedestrian walks at fast_speed (resp. slow_speed) at each time step
        proportion_fast: proportion of the pedestrians that are mainly walking at fast_speed
        perception_noise_std: standard deviation of the Gaussian noise affecting the position observations
    """

    batch_size: int
    time_scene: float
    sample_times: list
    ego_ref_speed: float
    ego_speed_init_low: float
    ego_speed_init_high: float
    ego_acceleration_mean_low: float
    ego_acceleration_mean_high: float
    ego_acceleration_std: float
    ego_length: float
    ego_width: float
    dt: float
    fast_speed: float
    slow_speed: float
    p_change_pace: float
    proportion_fast: float
    perception_noise_std: float

    @staticmethod
    def from_config(cfg: Config):
        return RandomSceneParams(
            batch_size=cfg.batch_size,
            sample_times=cfg.sample_times,
            time_scene=cfg.time_scene,
            ego_ref_speed=cfg.ego_ref_speed,
            ego_speed_init_low=cfg.ego_speed_init_low,
            ego_speed_init_high=cfg.ego_speed_init_high,
            ego_acceleration_mean_low=cfg.ego_acceleration_mean_low,
            ego_acceleration_mean_high=cfg.ego_acceleration_mean_high,
            ego_acceleration_std=cfg.ego_acceleration_std,
            ego_length=cfg.ego_length,
            ego_width=cfg.ego_width,
            dt=cfg.dt,
            fast_speed=cfg.fast_speed,
            slow_speed=cfg.slow_speed,
            p_change_pace=cfg.p_change_pace,
            proportion_fast=cfg.proportion_fast,
            perception_noise_std=cfg.perception_noise_std,
        )


class RandomScene:
    """
    Batched scenes with one vehicle at constant velocity and one random pedestrian. Utility functions to draw the scene and compute risk factors (time to collision, etc.)

    Args:
        params: dataclass containing the necessary parameters
        is_torch: set to True to produce Tensor batches and to False to produce numpy arrays
    """

    def __init__(
        self,
        params: RandomSceneParams,
        is_torch: bool = False,
    ) -> None:

        self._is_torch = is_torch
        self._batch_size = params.batch_size
        self._fast_speed = params.fast_speed
        self._slow_speed = params.slow_speed
        self._p_change_pace = params.p_change_pace
        self._proportion_fast = params.proportion_fast
        self.dt = params.dt
        self.sample_times = params.sample_times
        self.ego_ref_speed = params.ego_ref_speed
        self._ego_speed_init_low = params.ego_speed_init_low
        self._ego_speed_init_high = params.ego_speed_init_high
        self._ego_acceleration_mean_low = params.ego_acceleration_mean_low
        self._ego_acceleration_mean_high = params.ego_acceleration_mean_high
        self._ego_acceleration_std = params.ego_acceleration_std
        self.perception_noise_std = params.perception_noise_std
        self.road_length = (
            params.ego_ref_speed + params.fast_speed
        ) * params.time_scene
        self.time_scene = params.time_scene
        self.lane_width = 3
        self.sidewalks_width = 1.5
        self.road_width = 2 * self.lane_width + 2 * self.sidewalks_width
        self.bottom = -self.lane_width / 2 - self.sidewalks_width
        self.top = 3 * self.lane_width / 2 + self.sidewalks_width
        self.ego_width = 1.75
        self.ego_length = 4
        self.current_time = 0

        if self._is_torch:
            pedestrians_x = (
                torch.rand(params.batch_size, 1)
                * (self.road_length - self.ego_length / 2)
                + self.ego_length / 2
            )
            pedestrians_y = (
                torch.rand(params.batch_size, 1) * (self.top - self.bottom)
                + self.bottom
            )
            self._pedestrians_positions = torch.stack(
                (pedestrians_x, pedestrians_y), -1
            )
        else:
            pedestrians_x = np.random.uniform(
                low=self.ego_length / 2,
                high=self.road_length,
                size=(params.batch_size, 1),
            )
            pedestrians_y = np.random.uniform(
                low=self.bottom, high=self.top, size=(params.batch_size, 1)
            )
            self._pedestrians_positions = np.stack((pedestrians_x, pedestrians_y), -1)

        self.pedestrians = RandomPedestrians(
            batch_size=self._batch_size,
            dt=self.dt,
            fast_speed=self._fast_speed,
            slow_speed=self._slow_speed,
            p_change_pace=self._p_change_pace,
            proportion_fast=self._proportion_fast,
            is_torch=self._is_torch,
        )
        self._set_pedestrians()

    @property
    def pedestrians_positions(self):
        # relative_positions = self._pedestrians_positions/[[(self.road_length - self.ego_length / 2), (self.top - self.bottom)]] - [[self.ego_length / 2, self.bottom]]
        return self._pedestrians_positions

    def set_pedestrians_states(
        self,
        relative_pedestrians_positions: Union[torch.Tensor, np.ndarray],
        pedestrians_angles: Optional[Union[torch.Tensor, np.ndarray]] = None,
    ):
        """Force the pedestrian initial states

        Args:
            relative_pedestrians_positions: Relative positions in the scene as percentage distance from left to right and from bottom to top
            pedestrians_angles: Pedestrian heading angles in radians
        """
        if self._is_torch:
            assert isinstance(relative_pedestrians_positions, torch.Tensor)
        else:
            assert isinstance(relative_pedestrians_positions, np.ndarray)

        self._batch_size = relative_pedestrians_positions.shape[0]
        if (0 > relative_pedestrians_positions).any() or (
            relative_pedestrians_positions > 1
        ).any():
            warnings.warn(
                "Some of the given pedestrian initial positions are outside of the road range"
            )
        center_y = (self.top - self.bottom) * relative_pedestrians_positions[
            :, :, 1
        ] + self.bottom
        center_x = (
            self.road_length - self.ego_length / 2
        ) * relative_pedestrians_positions[:, :, 0] + self.ego_length / 2
        if self._is_torch:
            pedestrians_positions = torch.stack([center_x, center_y], -1)
        else:
            pedestrians_positions = np.stack([center_x, center_y], -1)

        self.pedestrians = RandomPedestrians(
            batch_size=self._batch_size,
            dt=self.dt,
            fast_speed=self._fast_speed,
            slow_speed=self._slow_speed,
            p_change_pace=self._p_change_pace,
            proportion_fast=self._proportion_fast,
            is_torch=self._is_torch,
        )
        self._pedestrians_positions = pedestrians_positions
        if pedestrians_angles is not None:
            self.pedestrians.angle = pedestrians_angles
        self._set_pedestrians()

    def _set_pedestrians(self):
        self.pedestrians_trajectories = self.sample_pedestrians_trajectories(
            self.sample_times
        )

        self.final_pedestrians_positions = self.pedestrians_trajectories[:, :, -1]

    def get_ego_ref_trajectory(self, time_sequence: list):
        """
        Returns only one ego reference trajectory and not a batch because it is always the same.
        Args:
            time_sequence: the time points at which to get the positions
        """
        out = np.array([[[[t * self.ego_ref_speed, 0] for t in time_sequence]]])
        if self._is_torch:
            return torch.from_numpy(out.astype("float32"))
        else:
            return out

    def get_pedestrians_velocities(self):
        """
        Returns the batch of mean pedestrian velocities between their positions and their final positions.
        """
        return (self.final_pedestrians_positions - self._pedestrians_positions)[
            :, None
        ] / self.time_scene

    def get_ego_ref_velocity(self):
        """
        Returns the reference ego velocity.
        """
        if self._is_torch:
            return torch.from_numpy(
                np.array([[[[self.ego_ref_speed, 0]]]], dtype="float32")
            )
        else:
            return np.array([[[[self.ego_ref_speed, 0]]]])

    def get_ego_ref_position(self):
        """
        Returns the current reference ego position (at the set time self.current_time)
        """
        if self._is_torch:
            return torch.from_numpy(
                np.array(
                    [[[[self.ego_ref_speed * self.current_time, 0]]]], dtype="float32"
                )
            )
        else:
            return np.array([[[[self.ego_ref_speed * self.current_time, 0]]]])

    def set_current_time(self, time: float):
        """
        Set the current time of the scene.
        Args:
            time : The current time to set. It should be between 0 and time_scene.
        """
        assert 0 <= time <= self.time_scene
        self.current_time = time

    def sample_ego_velocities(self, time_sequence: list):
        """
        Get ego velocity trajectories following the ego's acceleration distribution and the initial
        velocity distribution.

        Args:
            time_sequence: a list of time points at which to sample the trajectory positions.
        Returns:
            batch of sequences of velocities of shape (batch_size, 1, len(time_sequence), 2)
        """
        vel_traj = []
        # uniform sampling of acceleration_mean between self._ego_acceleration_mean_low and
        # self._ego_acceleration_mean_high
        acceleration_mean = np.random.rand(self._batch_size, 2) * np.array(
            [
                self._ego_acceleration_mean_high - self._ego_acceleration_mean_low,
                0.0,
            ]
        ) + np.array([self._ego_acceleration_mean_low, 0.0])
        t_prev = 0
        # uniform sampling of initial velocity between self._ego_speed_init_low and
        # self._ego_speed_init_high
        vel_prev = np.random.rand(self._batch_size, 2) * np.array(
            [self._ego_speed_init_high - self._ego_speed_init_low, 0.0]
        ) + np.array([self._ego_speed_init_low, 0.0])
        for t in time_sequence:
            # integrate accelerations once to get velocities
            acceleration = acceleration_mean + np.random.randn(
                self._batch_size, 2
            ) * np.array([self._ego_acceleration_std, 0.0])
            vel_prev = vel_prev + acceleration * (t - t_prev)
            t_prev = t
            vel_traj.append(vel_prev)
        vel_traj = np.stack(vel_traj, 1)
        if self._is_torch:
            vel_traj = torch.from_numpy(vel_traj.astype("float32"))
        return vel_traj[:, None]

    def sample_ego_trajectories(self, time_sequence: list):
        """
        Get ego trajectories following the ego's acceleration distribution and the initial velocity
        distribution.

        Args:
            time_sequence: a list of time points at which to sample the trajectory positions.
        Returns:
            batch of sequences of positions of shape (batch_size, 1, len(time_sequence), 2)
        """
        vel_traj = self.sample_ego_velocities(time_sequence)
        traj = []
        t_prev = 0
        pos_prev = np.array([[0, 0]], dtype="float32")
        if self._is_torch:
            pos_prev = torch.from_numpy(pos_prev)
        for idx, t in enumerate(time_sequence):
            # integrate velocities once to get positions
            vel = vel_traj[:, :, idx, :]
            pos_prev = pos_prev + vel * (t - t_prev)
            t_prev = t
            traj.append(pos_prev)
        if self._is_torch:
            return torch.stack(traj, -2)
        else:
            return np.stack(traj, -2)

    def sample_pedestrians_trajectories(self, time_sequence: list):
        """
        Produce pedestrian trajectories following the pedestrian behavior distribution
        (it is resampled, so the final position will not match self.final_pedestrians_positions)
        Args:
            time_sequence: a list of time points at which to sample the trajectory positions.
        Returns:
            batch of sequences of positions of shape (batch_size, 1, len(time_sequence), 2)
        """
        traj = []
        t_prev = 0
        pos_prev = self.pedestrians_positions
        for t in time_sequence:
            pos_prev = (
                pos_prev
                + self.pedestrians.get_final_position(t - t_prev)
                - self.pedestrians.position
            )
            t_prev = t
            traj.append(pos_prev)
        if self._is_torch:
            traj = torch.stack(traj, 2)
            return traj + torch.randn_like(traj) * self.perception_noise_std
        else:
            traj = np.stack(traj, 2)
            return traj + np.random.randn(*traj.shape) * self.perception_noise_std

    def get_pedestrians_trajectories(self):
        """
        Returns the batch of pedestrian trajectories sampled every dt.
        """
        return self.pedestrians_trajectories

    def get_pedestrian_trajectory(self, ind: int, time_sequence: Optional[list] = None):
        """
        Returns one pedestrian trajectory of index ind sampled at the times set in time_sequence.
        Args:
            ind: index of the pedestrian in the batch.
            time_sequence: a list of time points at which to sample the trajectory positions.
        Returns:
            A pedestrian trajectory of shape (len(time_sequence), 2)
        """
        len_traj = len(self.sample_times)
        if self._is_torch:
            ped_traj = torch_linspace(
                self.pedestrians_positions[ind],
                self.final_pedestrians_positions[ind],
                len_traj,
            )
        else:
            ped_traj = np.linspace(
                self.pedestrians_positions[ind],
                self.final_pedestrians_positions[ind],
                len_traj,
            )

        if time_sequence is not None:
            n_steps = [int(t / self.dt) for t in time_sequence]
        else:
            n_steps = range(int(self.time_scene / self.dt))
        return ped_traj[n_steps]


class SceneDataset(Dataset):
    """
    Dataset of scenes with one vehicle at constant velocity and one random pedestrian.
    The scenes are randomly generated, so the distribution can be sampled at each batch or pre-fetched.

    Args:
        len: int number of scenes per epoch
        params: dataclass defining all the necessary parameters
        pre_fetch: set to True to fetch the whole dataset at initialization
    """

    def __init__(
        self,
        len: int,
        params: RandomSceneParams,
        pre_fetch: bool = True,
    ) -> None:
        super().__init__()
        self._pre_fetch = pre_fetch
        self._len = len
        self._sample_times = params.sample_times
        self.params = copy.deepcopy(params)
        params.batch_size = len
        if self._pre_fetch:
            self.scene_set = RandomScene(
                params, is_torch=True
            ).sample_pedestrians_trajectories(self._sample_times)

    def __len__(self) -> int:
        return self._len

    # This is a hack: __getitem__ only returns the index so that the collate_fn can handle making the batch without looping on RandomScene creation.
    def __getitem__(self, index: int) -> Tensor:
        return index

    def collate_fn(self, index_list: list) -> Tensor:
        if self._pre_fetch:
            return self.scene_set[torch.from_numpy(np.array(index_list))]
        else:
            self.params.batch_size = len(index_list)
            return RandomScene(
|
442 |
+
self.params,
|
443 |
+
is_torch=True,
|
444 |
+
).sample_pedestrians_trajectories(self._sample_times)
|
445 |
+
|
446 |
+
|
447 |
+
# Call this function to create a dataset as a .npy file that can be loaded as a numpy array with np.load(file_name.npy)
|
448 |
+
def save_dataset(file_path: str, size: int, config: Config):
|
449 |
+
"""
|
450 |
+
Save a dataset at file_path using the configuration.
|
451 |
+
Args:
|
452 |
+
file_path: Where to save the dataset
|
453 |
+
size: Number of samples to save
|
454 |
+
config: Configuration to use for the dataset generation
|
455 |
+
"""
|
456 |
+
dir_path = os.path.dirname(file_path)
|
457 |
+
config_path = os.path.join(dir_path, "config.py")
|
458 |
+
config = copy.deepcopy(config)
|
459 |
+
config.batch_size = size
|
460 |
+
params = RandomSceneParams.from_config(config)
|
461 |
+
scene = RandomScene(
|
462 |
+
params,
|
463 |
+
is_torch=False,
|
464 |
+
)
|
465 |
+
data_pedestrian = scene.sample_pedestrians_trajectories(config.sample_times)
|
466 |
+
data_ego = scene.sample_ego_trajectories(config.sample_times)
|
467 |
+
data = np.stack([data_pedestrian, data_ego], 0)
|
468 |
+
np.save(file_path, data)
|
469 |
+
# Cannot use config.dump here because it is buggy and does not work if config was not loaded from a file.
|
470 |
+
with open(config_path, "w", encoding="utf-8") as f:
|
471 |
+
f.write(config.pretty_text)
|
472 |
+
|
473 |
+
|
474 |
+
def load_create_dataset(
|
475 |
+
config: Config,
|
476 |
+
base_dir=None,
|
477 |
+
) -> List:
|
478 |
+
"""
|
479 |
+
Load the dataset described by its config if it exists or create one.
|
480 |
+
|
481 |
+
Args:
|
482 |
+
config: Configuration to use for the dataset
|
483 |
+
base_dir: Where to look for the dataset or to save it.
|
484 |
+
"""
|
485 |
+
|
486 |
+
if base_dir is None:
|
487 |
+
base_dir = os.path.join(
|
488 |
+
os.path.dirname(os.path.realpath(__file__)), "..", "..", "data"
|
489 |
+
)
|
490 |
+
found = False
|
491 |
+
dataset_out = []
|
492 |
+
i = 0
|
493 |
+
dir_path = os.path.join(base_dir, f"scene_dataset_{i:03d}")
|
494 |
+
while os.path.exists(dir_path):
|
495 |
+
config_path = os.path.join(dir_path, "config.py")
|
496 |
+
if os.path.exists(config_path):
|
497 |
+
config_check = Config.fromfile(config_path)
|
498 |
+
if config_check.dataset_parameters == config.dataset_parameters:
|
499 |
+
found = True
|
500 |
+
break
|
501 |
+
else:
|
502 |
+
warnings.warn(
|
503 |
+
f"Dataset directory {dir_path} exists but doesn't contain a config file. Cannot use it."
|
504 |
+
)
|
505 |
+
i += 1
|
506 |
+
dir_path = os.path.join(base_dir, f"scene_dataset_{i:03d}")
|
507 |
+
|
508 |
+
if not found:
|
509 |
+
print(f"Dataset not found, creating a new one.")
|
510 |
+
os.makedirs(dir_path)
|
511 |
+
for dataset in config.datasets:
|
512 |
+
dataset_name = f"scene_dataset_{dataset}.npy"
|
513 |
+
dataset_path = os.path.join(dir_path, dataset_name)
|
514 |
+
save_dataset(dataset_path, config.datasets_sizes[dataset], config)
|
515 |
+
if found:
|
516 |
+
print(f"Loading existing dataset at {dir_path}.")
|
517 |
+
|
518 |
+
for dataset in config.datasets:
|
519 |
+
dataset_path = os.path.join(dir_path, f"scene_dataset_{dataset}.npy")
|
520 |
+
dataset_out.append(torch.from_numpy(np.load(dataset_path).astype("float32")))
|
521 |
+
|
522 |
+
return dataset_out
|
risk_biased/scene_dataset/scene_plotter.py
ADDED
@@ -0,0 +1,276 @@
import os
from typing import Optional

from matplotlib.axes import Axes
from matplotlib.collections import PatchCollection
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle, Ellipse
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams


class ScenePlotter:
    """
    This class defines plotting functions that take in a scene and an optional axes to plot road agents and trajectories.

    Args:
        scene: The scene to use for plotting
        ax: Matplotlib axes in which the drawing is made
    """

    def __init__(self, scene: RandomScene, ax: Optional[Axes] = None) -> None:
        self.scene = scene
        if ax is None:
            self.ax = plt.subplot()
        else:
            self.ax = ax
        self._sidewalks_boxes = PatchCollection(
            [
                Rectangle(
                    xy=[-scene.ego_length, scene.bottom],
                    height=scene.sidewalks_width,
                    width=scene.road_length + scene.ego_length,
                ),
                Rectangle(
                    xy=[-scene.ego_length, 3 * scene.lane_width / 2],
                    height=scene.sidewalks_width,
                    width=scene.road_length + scene.ego_length,
                ),
            ],
            facecolor="gray",
            alpha=0.3,
            edgecolor="black",
        )
        self._center_line = Line2D(
            [-scene.ego_length / 2, scene.road_length],
            [scene.lane_width / 2, scene.lane_width / 2],
            linewidth=4,
            color="black",
            dashes=[10, 5],
        )

        self._set_agent_patches()
        self._set_agent_paths()
        self.ax.set_aspect("equal")

    def _set_current_time(self, time: float):
        """
        Set the current time to draw the agents at the proper time along the trajectory.

        Args:
            time: the present time in seconds
        """
        self.scene.set_current_time(time)
        self._set_agent_patches()

    def _set_agent_paths(self):
        """
        Defines the agent paths as lines.
        """
        self._ego_path = Line2D(
            [0, self.scene.ego_ref_speed * self.scene.time_scene],
            [0, 0],
            linewidth=2,
            color="red",
            dashes=[4, 4],
            alpha=0.3,
        )

        self._pedestrian_path = [
            [
                Line2D(
                    [init[agent, 0], final[agent, 0]],
                    [init[agent, 1], final[agent, 1]],
                    linewidth=2,
                    dashes=[4, 4],
                    alpha=0.3,
                )
                for (init, final) in zip(
                    self.scene.pedestrians_positions,
                    self.scene.final_pedestrians_positions,
                )
            ]
            for agent in range(self.scene.pedestrians_positions.shape[1])
        ]

    def _set_agent_patches(self):
        """
        Set the agent patches at their current position in the scene.
        """
        current_step = int(round(self.scene.current_time / self.scene.dt))
        self._ego_box = Rectangle(
            xy=(
                -self.scene.ego_length / 2
                + self.scene.ego_ref_speed * self.scene.current_time,
                -self.scene.ego_width / 2,
            ),
            height=self.scene.ego_width,
            width=self.scene.ego_length,
            fill=True,
            facecolor="red",
            alpha=0.4,
            edgecolor="black",
        )
        self._pedestrian_patches = [
            [
                Ellipse(
                    xy=xy,
                    width=1,
                    height=0.5,
                    angle=angle * 180 / np.pi + 90,
                    facecolor="blue",
                    alpha=0.4,
                    edgecolor="black",
                )
                for xy, angle in zip(
                    self.scene.pedestrians_trajectories[:, agent, current_step],
                    self.scene.pedestrians.angle[:, agent],
                )
            ]
            for agent in range(self.scene.pedestrians_trajectories.shape[1])
        ]

    def plot_road(self) -> None:
        """
        Plot the road as two lanes and two sidewalks in straight lines with the ego vehicle. The plot is made in the given ax.
        """
        self.ax.add_collection(self._sidewalks_boxes)
        self.ax.add_patch(self._ego_box)
        self.ax.add_line(self._center_line)
        self.ax.add_line(self._ego_path)
        self.rescale()

    def draw_scene(self, index: int, time=None, prediction=None) -> None:
        """
        Plot the scene of given index (road, ego vehicle with its path, pedestrian with its path)
        Args:
            index: index of the pedestrian in the batch
            time: set current time to this value if not None
            prediction: draw this instead of the actual future if not None
        """
        if time is not None:
            self._set_current_time(time)
        self.plot_road()
        for agent_patch in self._pedestrian_patches:
            self.ax.add_patch(agent_patch[index])
        for agent_patch in self._pedestrian_path:
            self.ax.add_line(agent_patch[index])
        if prediction is not None:
            self.draw_trajectory(prediction)

    def rescale(self):
        """
        Set the x and y limits to the road shape with a margin.
        """
        self.ax.set_xlim(
            left=-2 * self.scene.ego_length,
            right=self.scene.road_length + self.scene.ego_length,
        )
        self.ax.set_ylim(
            bottom=self.scene.bottom - self.scene.lane_width,
            top=2 * self.scene.lane_width + 2 * self.scene.sidewalks_width,
        )

    def draw_trajectory(self, prediction, color="b") -> None:
        """
        Plot the given prediction in the scene.
        """
        self.ax.scatter(prediction[..., 0], prediction[..., 1], color=color, alpha=0.3)

    def draw_all_trajectories(
        self,
        prediction: np.ndarray,
        color=None,
        color_value: np.ndarray = None,
        alpha: float = 0.05,
        label: str = "trajectory",
    ) -> None:
        """
        Plot all the given predictions in the scene
        Args:
            prediction : (batch, n_agents, time, 2) batch of trajectories
            color: regular color name
            color_value : (batch) Optional batch of values for coloring from green to red
        """

        if color_value is not None:
            min = color_value.min()
            max = color_value.max()
            color_value = 0.9 * (color_value - min) / (max - min)
            for agent in range(prediction.shape[1]):
                for traj, val in zip(prediction[:, agent], color_value[:, agent]):
                    color = (val, 1 - val, 0.1)
                    self.ax.plot(
                        traj[:, 0], traj[:, 1], color=color, alpha=alpha, label=label
                    )
                    self.ax.scatter(traj[-1, 0], traj[-1, 1], color=color, alpha=alpha)
            cmap = matplotlib.colors.ListedColormap(
                np.linspace(
                    [color_value.min(), 1 - color_value.min(), 0.1],
                    [color_value.max(), 1 - color_value.max(), 0.1],
                    128,
                )
            )
            norm = matplotlib.colors.Normalize(vmin=min, vmax=max, clip=True)
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            plt.colorbar(sm, label="TTC cost")
        else:
            for agent in range(prediction.shape[1]):
                for traj in prediction:
                    self.ax.plot(
                        traj[agent, :, 0],
                        traj[agent, :, 1],
                        color=color,
                        alpha=alpha,
                        label=label,
                    )
                self.ax.scatter(
                    prediction[:, agent, -1, 0],
                    prediction[:, agent, -1, 1],
                    color=color,
                    alpha=alpha,
                )

    def draw_legend(self):
        """Draw the legend without repeats and without transparency."""

        handles, labels = self.ax.get_legend_handles_labels()
        i = np.arange(len(labels))
        filter = np.array([])
        unique_labels = list(set(labels))
        for ul in unique_labels:
            filter = np.append(filter, [i[np.array(labels) == ul][0]])
        filtered_handles = []
        for f in filter:
            handles[int(f)].set_alpha(1)
            filtered_handles.append(handles[int(f)])
        filtered_labels = [labels[int(f)] for f in filter]
        self.ax.legend(filtered_handles, filtered_labels)


# Draw a random scene
if __name__ == "__main__":
    from risk_biased.utils.config_argparse import config_argparse

    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "risk_biased", "config", "learning_config.py"
    )
    config = config_argparse(config_path)
    n_samples = 100

    scene_params = RandomSceneParams.from_config(config)
    scene_params.batch_size = n_samples
    scene = RandomScene(
        scene_params,
        is_torch=False,
    )

    plotter = ScenePlotter(scene)

    plotter.draw_scene(0, time=1)
    plt.tight_layout()
    plt.show()
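
Complementing the __main__ demo above, a hedged sketch of draw_all_trajectories with value-based coloring; the random cost array is a stand-in for the TTC costs that the training callbacks pass as color_value, and the config path is assumed.

import matplotlib.pyplot as plt
import numpy as np

from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams
from risk_biased.scene_dataset.scene_plotter import ScenePlotter
from risk_biased.utils.config_argparse import config_argparse

config = config_argparse("risk_biased/config/learning_config.py")  # assumed config path
params = RandomSceneParams.from_config(config)
params.batch_size = 50
scene = RandomScene(params, is_torch=False)

plotter = ScenePlotter(scene)
plotter.draw_scene(0, time=0.0)
trajs = scene.sample_pedestrians_trajectories(config.sample_times)
fake_costs = np.random.rand(*trajs.shape[:2])  # stand-in values, one per trajectory
plotter.draw_all_trajectories(trajs, color_value=fake_costs, label="sampled future")
plotter.draw_legend()
plt.show()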
risk_biased/utils/__init__.py
ADDED
File without changes
risk_biased/utils/callbacks.py
ADDED
@@ -0,0 +1,595 @@
import copy
from dataclasses import dataclass

from mmcv import Config
import matplotlib.pyplot as plt
import numpy as np
from pydantic import NoneBytes
import pytorch_lightning as pl
import torch
import wandb

from risk_biased.scene_dataset.loaders import SceneDataLoaders
from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams
from risk_biased.scene_dataset.scene_plotter import ScenePlotter
from risk_biased.utils.cost import (
    DistanceCostNumpy,
    DistanceCostParams,
    TTCCostNumpy,
    TTCCostParams,
)
from risk_biased.utils.risk import get_risk_level_sampler


class SwitchTrainingModeCallback(pl.Callback):
    """
    This callback switches between CVAE training and biasing training for the biased_latent_cvae_model
    Args:
        switch_at_epoch: The number of epochs after which to make the switch. The CVAE is not trained anymore after that point.
    """

    def __init__(self, switch_at_epoch: int) -> None:
        super().__init__()
        self._switch_at_epoch = switch_at_epoch
        self._train_has_started = False

    def on_train_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Store the optimizer list and set the trainer to the first optimizer."""
        self._optimizers = trainer.optimizers
        trainer.optimizers = [self._optimizers[0]]
        self._train_has_started = True

    def on_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """
        Check if the switch should be made and if so,
        set the trainer on the second optimizer.
        """
        if trainer.current_epoch == self._switch_at_epoch and self._train_has_started:
            print("Switching to bias training.")
            pl_module.set_training_mode("bias")
            trainer.optimizers = [self._optimizers[1]]


def get_fast_slow_scenes(params: RandomSceneParams, n_samples: int):
    """Define and return two RandomScene objects, one initialized such that slow
    pedestrians are safer and the other such that fast pedestrians are safer.

    Args:
        params: dataclass containing the necessary parameters for a RandomScene object
        n_samples: number of samples to draw in each scene
    """
    params = copy.deepcopy(params)
    params.batch_size = n_samples
    scene_safe_slow = RandomScene(
        params,
        is_torch=False,
    )
    percent_right = 0.8
    percent_top = 1.1
    angle = 5 * np.pi / 4
    positions = np.array([[[percent_right, percent_top]]] * n_samples)
    angles = np.array([[angle]] * n_samples)
    scene_safe_slow.set_pedestrians_states(positions, angles)

    scene_safe_fast = RandomScene(
        params,
        is_torch=False,
    )
    percent_right = 0.8
    percent_top = 0.6
    angle = 5 * np.pi / 4
    positions = np.array([[[percent_right, percent_top]]] * n_samples)
    angles = np.array([[angle]] * n_samples)
    scene_safe_fast.set_pedestrians_states(positions, angles)
    return scene_safe_fast, scene_safe_slow


@dataclass
class DrawCallbackParams:
    """
    Args:
        scene_params: dataclass parameters for the RandomScene
        dist_cost_params: dataclass parameters for the DistanceCost
        ttc_cost_params: dataclass parameters for the TTCCost
        plot_interval_epoch: number of epochs between each plot drawing
        histogram_interval_epoch: number of epochs between each histogram drawing
        num_steps: number of time steps as defined in the config
        num_steps_future: number of time steps in the future as defined in the config
        risk_distribution: dict object describing a risk distribution
        dt: time step size as defined in the config
    """

    scene_params: RandomSceneParams
    dist_cost_params: DistanceCostParams
    ttc_cost_params: TTCCostParams
    plot_interval_epoch: int
    histogram_interval_epoch: int
    num_steps: int
    num_steps_future: int
    risk_distribution: dict
    dt: float

    @staticmethod
    def from_config(cfg: Config):
        return DrawCallbackParams(
            scene_params=RandomSceneParams.from_config(cfg),
            dist_cost_params=DistanceCostParams.from_config(cfg),
            ttc_cost_params=TTCCostParams.from_config(cfg),
            plot_interval_epoch=cfg.plot_interval_epoch,
            histogram_interval_epoch=cfg.histogram_interval_epoch,
            num_steps=cfg.num_steps,
            num_steps_future=cfg.num_steps_future,
            risk_distribution=cfg.risk_distribution,
            dt=cfg.dt,
        )


class HistogramCallback(pl.Callback):
    """Logs histograms of distances, distance cost and TTC cost for the data, the predictions at risk_level=0, and the predictions at risk_level=1
    Args:
        params: dataclass defining the necessary parameters
        n_samples: Number of samples to use for the histogram plot
    """

    def __init__(
        self,
        params: DrawCallbackParams,
        n_samples=1000,
    ):
        super().__init__()
        self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes(
            params.scene_params, n_samples
        )
        self.num_steps = params.num_steps
        self.n_scenes = n_samples
        self.sample_times = params.scene_params.sample_times
        self.dist_cost_func = DistanceCostNumpy(params.dist_cost_params)
        self.ttc_cost_func = TTCCostNumpy(params.ttc_cost_params)
        self.histogram_interval_epoch = params.histogram_interval_epoch

        self.ego_traj = self.scene_safe_fast.get_ego_ref_trajectory(self.sample_times)

        self._risk_sampler = get_risk_level_sampler(params.risk_distribution)

    def _log_scene(self, pl_module: pl.LightningModule, scene: RandomScene, name: str):
        """
        Log in WandB three histograms for the given scene: one for the data, one for the predictions at risk_level=0 and one for the predictions at risk_level=1
        Args:
            pl_module: LightningModule object
            scene: RandomScene object
            name: name of the given scene
        """
        ped_trajs = scene.get_pedestrians_trajectories()
        device = pl_module.device
        n_agents = ped_trajs.shape[1]

        input_traj = ped_trajs[..., : self.num_steps, :]

        normalized_input, offset = SceneDataLoaders.normalize_trajectory(
            torch.from_numpy(input_traj.astype("float32")).contiguous().to(device)
        )
        mask_input = torch.ones_like(normalized_input[..., 0])
        ego_history = (
            torch.from_numpy(self.ego_traj[..., : self.num_steps, :].astype("float32"))
            .expand_as(normalized_input)
            .contiguous()
            .to(device)
        )
        ego_future = (
            torch.from_numpy(self.ego_traj[..., self.num_steps :, :].astype("float32"))
            .expand(normalized_input.shape[0], n_agents, -1, -1)
            .contiguous()
            .to(device)
        )
        map = torch.empty(ego_history.shape[0], 0, 0, 2, device=mask_input.device)
        mask_map = torch.empty(ego_history.shape[0], 0, 0, device=mask_input.device)

        pred_riskier = (
            pl_module.predict_step(
                (
                    normalized_input,
                    mask_input,
                    map,
                    mask_map,
                    offset,
                    ego_history,
                    ego_future,
                ),
                0,
                risk_level=self._risk_sampler.get_highest_risk(
                    batch_size=self.n_scenes, device=device
                )
                .unsqueeze(1)
                .repeat(1, n_agents),
            )
            .cpu()
            .detach()
            .numpy()
        )

        pred = (
            pl_module.predict_step(
                (
                    normalized_input,
                    mask_input,
                    map,
                    mask_map,
                    offset,
                    ego_history,
                    ego_future,
                ),
                0,
                risk_level=None,
            )
            .cpu()
            .detach()
            .numpy()
        )

        ped_trajs_pred = np.concatenate((input_traj, pred), axis=-2)
        ped_trajs_pred_riskier = np.concatenate((input_traj, pred_riskier), axis=-2)

        travel_distances = np.sqrt(
            np.square(ped_trajs[..., -1, :] - ped_trajs[..., 0, :]).sum(-1)
        )

        dist_cost, dist = self.dist_cost_func(
            self.ego_traj[..., self.num_steps :, :],
            ped_trajs[..., self.num_steps :, :],
        )

        ttc_cost, (ttc, dist) = self.ttc_cost_func(
            self.ego_traj[..., self.num_steps :, :],
            ped_trajs[..., self.num_steps :, :],
            scene.get_ego_ref_velocity(),
            scene.get_pedestrians_velocities(),
        )

        travel_distances_pred = np.sqrt(
            np.square(ped_trajs_pred[..., -1, :] - ped_trajs_pred[..., 0, :]).sum(-1)
        )
        dist_cost_pred, dist_pred = self.dist_cost_func(
            self.ego_traj[..., self.num_steps :, :],
            ped_trajs_pred[..., self.num_steps :, :],
        )
        sample_times = np.array(self.sample_times)
        ped_velocities_pred = (
            ped_trajs_pred[..., 1:, :] - ped_trajs_pred[..., :-1, :]
        ) / ((sample_times[1:] - sample_times[:-1])[None, None, :, None])
        ped_velocities_pred = np.concatenate(
            (ped_velocities_pred[..., 0:1, :], ped_velocities_pred), -2
        )
        ttc_cost_pred, (ttc_pred, dist_pred) = self.ttc_cost_func(
            self.ego_traj[..., self.num_steps :, :],
            ped_trajs_pred[..., self.num_steps :, :],
            scene.get_ego_ref_velocity(),
            ped_velocities_pred[..., self.num_steps :, :],
        )

        travel_distances_pred_riskier = np.sqrt(
            np.square(
                ped_trajs_pred_riskier[..., -1, :] - ped_trajs_pred_riskier[..., 0, :]
            ).sum(-1)
        )

        dist_cost_pred_riskier, dist_pred_riskier = self.dist_cost_func(
            self.ego_traj[..., self.num_steps :, :],
            ped_trajs_pred_riskier[..., self.num_steps :, :],
        )
        sample_times = np.array(self.sample_times)
        ped_velocities_pred_riskier = (
            ped_trajs_pred_riskier[..., 1:, :] - ped_trajs_pred_riskier[..., :-1, :]
        ) / ((sample_times[1:] - sample_times[:-1])[None, None, :, None])
        ped_velocities_pred_riskier = np.concatenate(
            (ped_velocities_pred_riskier[..., 0:1, :], ped_velocities_pred_riskier), -2
        )
        ttc_cost_pred_riskier, (ttc_pred, dist_pred_riskier) = self.ttc_cost_func(
            self.ego_traj[..., self.num_steps :, :],
            ped_trajs_pred_riskier[..., self.num_steps :, :],
            scene.get_ego_ref_velocity(),
            ped_velocities_pred_riskier[..., self.num_steps :, :],
        )
        data = [
            [dist, dist_pred, dist_risk]
            for (dist, dist_pred, dist_risk) in zip(
                travel_distances.flatten(),
                travel_distances_pred.flatten(),
                travel_distances_pred_riskier.flatten(),
            )
        ]
        table_travel_distance = wandb.Table(
            data=data,
            columns=[
                "Travel distance data " + name,
                "Travel distance prediction " + name,
                "Travel distance riskier " + name,
            ],
        )
        data = [
            [cost, cost_pred, cost_risk]
            for (cost, cost_pred, cost_risk) in zip(
                dist_cost.flatten(),
                dist_cost_pred.flatten(),
                dist_cost_pred_riskier.flatten(),
            )
        ]
        table_distance_cost = wandb.Table(
            data=data,
            columns=[
                "Distance cost data " + name,
                "Distance cost prediction " + name,
                "Distance cost riskier " + name,
            ],
        )
        data = [
            [ttc, ttc_pred, ttc_risk]
            for (ttc, ttc_pred, ttc_risk) in zip(
                ttc_cost.flatten(),
                ttc_cost_pred.flatten(),
                ttc_cost_pred_riskier.flatten(),
            )
        ]
        table_ttc_cost = wandb.Table(
            data=data,
            columns=[
                "TTC cost data " + name,
                "TTC cost prediction " + name,
                "TTC cost riskier " + name,
            ],
        )
        wandb.log(
            {
                "Travel distance data "
                + name: wandb.plot_table(
                    vega_spec_name="jmercat/histogram_01_bins",
                    data_table=table_travel_distance,
                    fields={
                        "value": "Travel distance data " + name,
                        "title": "Travel distance data " + name,
                    },
                ),
                "Travel distance prediction "
                + name: wandb.plot_table(
                    vega_spec_name="jmercat/histogram_01_bins",
                    data_table=table_travel_distance,
                    fields={
                        "value": "Travel distance prediction " + name,
                        "title": "Travel distance prediction " + name,
                    },
                ),
                "Travel distance riskier "
                + name: wandb.plot_table(
                    vega_spec_name="jmercat/histogram_01_bins",
                    data_table=table_travel_distance,
                    fields={
                        "value": "Travel distance riskier " + name,
                        "title": "Travel distance riskier " + name,
                    },
                ),
                "Distance cost data "
                + name: wandb.plot_table(
                    vega_spec_name="jmercat/histogram_0025_bins",
                    data_table=table_distance_cost,
                    fields={
                        "value": "Distance cost data " + name,
                        "title": "Distance cost data " + name,
                    },
                ),
                "Distance cost prediction "
                + name: wandb.plot_table(
                    vega_spec_name="jmercat/histogram_0025_bins",
                    data_table=table_distance_cost,
                    fields={
                        "value": "Distance cost prediction " + name,
                        "title": "Distance cost prediction " + name,
                    },
                ),
                "Distance cost riskier "
                + name: wandb.plot_table(
                    vega_spec_name="jmercat/histogram_0025_bins",
                    data_table=table_distance_cost,
                    fields={
                        "value": "Distance cost riskier " + name,
                        "title": "Distance cost riskier " + name,
                    },
                ),
                "TTC cost data "
                + name: wandb.plot_table(
                    vega_spec_name="jmercat/histogram_005_bins",
                    data_table=table_ttc_cost,
                    fields={
                        "value": "TTC cost data " + name,
                        "title": "TTC cost data " + name,
                    },
                ),
                "TTC cost prediction "
                + name: wandb.plot_table(
                    vega_spec_name="jmercat/histogram_005_bins",
                    data_table=table_ttc_cost,
                    fields={
                        "value": "TTC cost prediction " + name,
                        "title": "TTC cost prediction " + name,
                    },
                ),
                "TTC cost riskier "
                + name: wandb.plot_table(
                    vega_spec_name="jmercat/histogram_005_bins",
                    data_table=table_ttc_cost,
                    fields={
                        "value": "TTC cost riskier " + name,
                        "title": "TTC cost riskier " + name,
                    },
                ),
            }
        )

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """After a validation at the end of every histogram_interval_epoch,
        log the histograms for two scenes: the safer fast scene and the safer slow scene.
        """
        if (
            trainer.current_epoch % self.histogram_interval_epoch
            == self.histogram_interval_epoch - 1
        ):
            self._log_scene(pl_module, self.scene_safe_fast, name="Safer fast")
            self._log_scene(pl_module, self.scene_safe_slow, name="Safer slow")


class PlotTrajCallback(pl.Callback):
    """Plot trajectory samples for two scenes:
    One that is safer for the slow pedestrians
    One that is safer for the fast pedestrians
    Samples of ground truth, prediction, and biased predictions are superposed.
    Last positions are marked to visualize the clusters.

    Args:
        params: dataclass containing the necessary drawing parameters (DrawCallbackParams)
        n_samples: number of sample trajectories to draw
    """

    def __init__(
        self,
        params: DrawCallbackParams,
        n_samples: int = 1,
    ):
        super().__init__()
        self.n_samples = n_samples
        self.num_steps = params.num_steps
        self.dt = params.scene_params.dt
        self.scene_params = params.scene_params
        self.plot_interval_epoch = params.plot_interval_epoch
        self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes(
            params.scene_params, n_samples
        )
        self.ego_traj = self.scene_safe_fast.get_ego_ref_trajectory(
            params.scene_params.sample_times
        )
        self._risk_sampler = get_risk_level_sampler(params.risk_distribution)

    def _log_scene(self, epoch: int, pl_module, scene: RandomScene, name: str) -> None:
        """Add drawing of samples of prediction, biased prediction and ground truth in the scene.

        Args:
            epoch: current epoch calling the log
            pl_module: pytorch lightning module being trained
            scene: scene to draw
            name: name of the scene
        """
        ped_trajs = scene.get_pedestrians_trajectories()
        device = pl_module.device
        n_agents = ped_trajs.shape[1]

        input_traj = ped_trajs[..., : self.num_steps, :]

        normalized_input, offset = SceneDataLoaders.normalize_trajectory(
            torch.from_numpy(input_traj.astype("float32")).contiguous().to(device)
        )
        mask_input = torch.ones_like(normalized_input[..., 0])
        ego_history = (
            torch.from_numpy(self.ego_traj[..., : self.num_steps, :].astype("float32"))
            .expand_as(normalized_input)
            .contiguous()
            .to(device)
        )
        ego_future = (
            torch.from_numpy(self.ego_traj[..., self.num_steps :, :].astype("float32"))
            .expand(normalized_input.shape[0], n_agents, -1, -1)
            .contiguous()
            .to(device)
        )
        map = torch.empty(ego_history.shape[0], 0, 0, 2, device=mask_input.device)
        mask_map = torch.empty(ego_history.shape[0], 0, 0, device=mask_input.device)

        pred_riskier = (
            pl_module.predict_step(
                (
                    normalized_input,
                    mask_input,
                    map,
                    mask_map,
                    offset,
                    ego_history,
                    ego_future,
                ),
                0,
                risk_level=self._risk_sampler.get_highest_risk(
                    batch_size=self.n_samples, device=device
                )
                .unsqueeze(1)
                .repeat(1, n_agents),
            )
            .cpu()
            .detach()
            .numpy()
        )

        pred = (
            pl_module.predict_step(
                (
                    normalized_input,
                    mask_input,
                    map,
                    mask_map,
                    offset,
                    ego_history,
                    ego_future,
                ),
                0,
                risk_level=None,
            )
            .cpu()
            .detach()
            .numpy()
        )

        fig, ax = plt.subplots()
        plotter = ScenePlotter(scene, ax=ax)
        fig.set_size_inches(h=scene.road_width / 3 + 1, w=scene.road_length / 3)

        time = self.dt * self.num_steps
        plotter.draw_scene(0, time=time)
        alpha = 0.5 / np.log(self.n_samples)
        plotter.draw_all_trajectories(
            ped_trajs[..., self.num_steps :, :],
            color="g",
            alpha=alpha,
            label="Future ground truth",
        )
        plotter.draw_all_trajectories(
            input_traj, color="b", alpha=alpha, label="Past input"
        )
        plotter.draw_all_trajectories(
            pred, color="orange", alpha=alpha, label="Prediction"
        )
        plotter.draw_all_trajectories(
            pred_riskier, color="r", alpha=alpha, label="Prediction risk-seeking"
        )
        plotter.draw_legend()
        plt.tight_layout()
        wandb.log({"Road scene " + name: wandb.Image(fig), "epoch": epoch})
        plt.close()

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """After a validation at the end of every plot_interval_epoch,
        log the prediction samples for two scenes: the safer fast scene and the safer slow scene.
        """
        if (
            trainer.current_epoch % self.plot_interval_epoch
            == self.plot_interval_epoch - 1
        ):
            self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes(
                self.scene_params, self.n_samples
            )
            self._log_scene(
                trainer.current_epoch, pl_module, self.scene_safe_slow, "Safer slow"
            )
            self._log_scene(
                trainer.current_epoch, pl_module, self.scene_safe_fast, "Safer fast"
            )


# TODO: make the same kind of logs for the Waymo dataset
risk_biased/utils/config_argparse.py
ADDED
@@ -0,0 +1,96 @@
import os
from typing import Optional, Union, List
import warnings

import argparse
from mmcv import Config


def config_argparse(config_path: Optional[Union[str, List[str]]] = None) -> Config:
    """Function that loads the config file as an MMCV Config object and overwrites its values with argparsed arguments.

    Args:
        config_path : path of the mmcv config file

    Returns:
        MMCV Config object with default values from the config_path and overwritten values from argparse
    """
    if config_path is None:
        working_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(
            working_dir, "..", "..", "config", "learning_config.py"
        )
    if isinstance(config_path, str):
        cfg = Config.fromfile(config_path)
    else:
        cfg = Config.fromfile(config_path[0])
        for path in config_path[1:]:
            c = Config.fromfile(path)
            cfg.update(c)

    parser = argparse.ArgumentParser()
    excluded_args = ["force_config", "load_last"]
    overwritable_types = (str, float, int, list)
    for key, value in cfg.items():
        if key not in excluded_args:
            if list in overwritable_types and isinstance(value, list):
                if len(value) > 0:
                    parser.add_argument(
                        "--" + key, default=value, nargs="+", type=type(value[0])
                    )
                else:
                    parser.add_argument("--" + key, default=value, nargs="+")
            elif isinstance(value, overwritable_types):
                parser.add_argument("--" + key, default=value, type=type(value))

    if "load_from" not in cfg.keys():
        parser.add_argument(
            "--load_from",
            default="",
            type=str,
            help="""Use this to load the model weights from a wandb checkpoint,
            refer to the checkpoint with the wandb id (example: '1f1ho81a')""",
        )

    parser.add_argument(
        "--load_last",
        action="store_true",
        help="""Use this flag to force the use of the last checkpoint instead of the best one
        when loading a model.""",
    )

    parser.add_argument(
        "--force_config",
        action="store_true",
        help="""Use this flag to force the use of the local config file
        when loading a model from a checkpoint. Otherwise the checkpoint config file is used.
        In any case the parameters can be overwritten with an argparse argument.""",
    )
    if "force_config" not in cfg.keys():
        parser.set_defaults(force_config=False)
    else:
        parser.set_defaults(force_config=cfg.force_config)

    if "load_last" not in cfg.keys():
        parser.set_defaults(load_last=False)
    else:
        parser.set_defaults(load_last=cfg.load_last)

    args = parser.parse_args()

    # Print a warning in case the parameter 'dt' or 'time_scene' is changed because 'sample_times' might need to be updated accordingly.
    if (
        args.dt != cfg.dt or args.time_scene != cfg.time_scene
    ) and args.sample_times == cfg.sample_times:
        warnings.warn(
            f"""Parameter 'dt' has been changed from {cfg.dt} to {args.dt} by
            a command line argument; it might be used to set the parameter 'sample_times' that
            cannot be updated accordingly. Consider setting 'dt' in {config_path} instead."""
        )
    # Config has a dataset_parameters field that copies the parameters related to the dataset to compare them;
    # they must be updated too if some of the dataset parameters were changed by argparse.
    for key, value in cfg.dataset_parameters.items():
        if isinstance(value, overwritable_types):
            cfg.dataset_parameters[key] = args.__getattribute__(key)
    cfg.update(args.__dict__)
    return cfg
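
In practice every scalar, string, or list entry of the loaded config becomes an optional command-line flag, so a hypothetical invocation such as the one sketched below overrides those fields while everything else keeps its file defaults (the script name and flag values are illustrative, not a documented CLI):

# e.g. python some_training_script.py --batch_size 64 --dt 0.2
from risk_biased.utils.config_argparse import config_argparse

cfg = config_argparse("risk_biased/config/learning_config.py")  # assumed config path
# cfg.batch_size is now 64 if the flag was passed, otherwise the value from the file;
# changing --dt without --sample_times triggers the warning defined above.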
risk_biased/utils/cost.py
ADDED
@@ -0,0 +1,539 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Any, Callable, Optional, Tuple
|
3 |
+
|
4 |
+
from mmcv import Config
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
|
10 |
+
def masked_min_torch(x, mask=None, dim=None):
|
11 |
+
if mask is not None:
|
12 |
+
x = torch.masked_fill(x, torch.logical_not(mask), float("inf"))
|
13 |
+
if dim is None:
|
14 |
+
return torch.min(x)
|
15 |
+
else:
|
16 |
+
return torch.min(x, dim=dim)[0]
|
17 |
+
|
18 |
+
|
19 |
+
def masked_max_torch(x, mask=None, dim=None):
|
20 |
+
if mask is not None:
|
21 |
+
x = torch.masked_fill(x, torch.logical_not(mask), float("-inf"))
|
22 |
+
if dim is None:
|
23 |
+
return torch.max(x)
|
24 |
+
else:
|
25 |
+
return torch.max(x, dim=dim)[0]
|
26 |
+
|
27 |
+
|
28 |
+
def get_masked_discounted_mean_torch(discount_factor=0.95):
|
29 |
+
def masked_discounted_mean_torch(x, mask=None, dim=None):
|
30 |
+
discount_tensor = torch.full(x.shape, discount_factor, device=x.device)
|
31 |
+
discount_tensor = torch.cumprod(discount_tensor, dim=-2)
|
32 |
+
if mask is not None:
|
33 |
+
x = torch.masked_fill(x, torch.logical_not(mask), 0)
|
34 |
+
if dim is None:
|
35 |
+
assert mask.any()
|
36 |
+
return (x * discount_tensor).sum() / (mask * discount_tensor).sum()
|
37 |
+
else:
|
38 |
+
return (x * discount_tensor).sum(dim) / (mask * discount_tensor).sum(
|
39 |
+
dim
|
40 |
+
).clamp_min(1)
|
41 |
+
else:
|
42 |
+
if dim is None:
|
43 |
+
return (x * discount_tensor).sum() / discount_tensor.sum()
|
44 |
+
else:
|
45 |
+
return (x * discount_tensor).sum(dim) / discount_tensor.sum(dim)
|
46 |
+
|
47 |
+
return masked_discounted_mean_torch
|
48 |
+
|
49 |
+
|
50 |
+
def masked_mean_torch(x, mask=None, dim=None):
|
51 |
+
if mask is not None:
|
52 |
+
x = torch.masked_fill(x, torch.logical_not(mask), 0)
|
53 |
+
if dim is None:
|
54 |
+
assert mask.any()
|
55 |
+
return x.sum() / mask.sum()
|
56 |
+
else:
|
57 |
+
return x.sum(dim) / mask.sum(dim).clamp_min(1)
|
58 |
+
else:
|
59 |
+
if dim is None:
|
60 |
+
return x.mean()
|
61 |
+
else:
|
62 |
+
return x.mean(dim)
|
63 |
+
|
64 |
+
|
65 |
+
def get_discounted_mean_np(discount_factor=0.95):
|
66 |
+
def discounted_mean_np(x, axis=None):
|
67 |
+
discount_tensor = np.full(x.shape, discount_factor)
|
68 |
+
discount_tensor = np.cumprod(discount_tensor, axis=-2)
|
69 |
+
if axis is None:
|
70 |
+
return (x * discount_tensor).sum() / discount_tensor.sum()
|
71 |
+
else:
|
72 |
+
return (x * discount_tensor).sum(axis) / discount_tensor.sum(axis)
|
73 |
+
|
74 |
+
return discounted_mean_np
|
75 |
+
|
76 |
+
|
77 |
+
def get_masked_reduce_np(reduce_function):
|
78 |
+
def masked_reduce_np(x, mask=None, axis=None):
|
79 |
+
if mask is not None:
|
80 |
+
x = np.ma.array(x, mask=np.logical_not(mask))
|
81 |
+
return reduce_function(x, axis=axis)
|
82 |
+
else:
|
83 |
+
return reduce_function(x, axis=axis)
|
84 |
+
|
85 |
+
return masked_reduce_np
|
86 |
+
|
87 |
+
|
88 |
+
@dataclass
|
89 |
+
class CostParams:
|
90 |
+
scale: float
|
91 |
+
reduce: str
|
92 |
+
discount_factor: float
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def from_config(cfg: Config):
|
96 |
+
return CostParams(
|
97 |
+
scale=cfg.cost_scale,
|
98 |
+
reduce=cfg.cost_reduce,
|
99 |
+
discount_factor=cfg.discount_factor,
|
100 |
+
)
|
101 |
+
|
102 |
+
|
103 |
+
class BaseCostTorch:
|
104 |
+
"""Base cost class defining reduce strategy and basic parameters.
|
105 |
+
Its __call__ definition is only a dummy example returning zeros, this class is intended to be
|
106 |
+
inherited from and __call__ redefined with an actual cost between the inputs.
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(self, params: CostParams) -> None:
|
110 |
+
super().__init__()
|
111 |
+
self._reduce_fun = params.reduce
|
112 |
+
self.scale = params.scale
|
113 |
+
|
114 |
+
reduce_fun_torch_dict = {
|
115 |
+
"min": masked_min_torch,
|
116 |
+
"max": masked_max_torch,
|
117 |
+
"mean": masked_mean_torch,
|
118 |
+
"discounted_mean": get_masked_discounted_mean_torch(params.discount_factor),
|
119 |
+
"now": lambda *args, **kwargs: args[0][..., 0],
|
120 |
+
"final": lambda *args, **kwargs: args[0][..., -1],
|
121 |
+
}
|
122 |
+
|
123 |
+
self._reduce_fun = reduce_fun_torch_dict[params.reduce]
|
124 |
+
|
125 |
+
@property
|
126 |
+
def distance_bandwidth(self):
|
127 |
+
return 1
|
128 |
+
|
129 |
+
@property
|
130 |
+
def time_bandwidth(self):
|
131 |
+
return 1
|
132 |
+
|
133 |
+
def __call__(
|
134 |
+
self,
|
135 |
+
x1: Tensor,
|
136 |
+
x2: Tensor,
|
137 |
+
v1: Tensor,
|
138 |
+
v2: Tensor,
|
139 |
+
mask: Optional[Tensor] = None,
|
140 |
+
) -> Tuple[Tensor, Any]:
|
141 |
+
"""Compute the cost from given positions x1, x2 and velocities v1, v2
|
142 |
+
The base cost only returns 0 cost, use costs that inherit from this to compute an actual cost.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
x1 (some shape, num_steps, 2): positions of the first agent
|
146 |
+
x2 (some shape, num_steps, 2): positions of the second agent
|
147 |
+
v1 (some shape, num_steps, 2): velocities of the first agent
|
148 |
+
v2 (some shape, num_steps, 2): velocities of the second agent
|
149 |
+
mask (some_shape, num_steps, 2): mask set to True where the cost should be computed
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
(some_shape) cost for the compared states of agent 1 and agent 2, as well as any
|
153 |
+
supplementary cost-related information
|
154 |
+
"""
|
155 |
+
return (
|
156 |
+
self._reduce_fun(torch.zeros_like(x2[..., 0]), mask, dim=-1),
|
157 |
+
None,
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
class BaseCostNumpy:
|
162 |
+
"""Base cost class defining reduce strategy and basic parameters.
|
163 |
+
Its __call__ definition is only a dummy example returning zeros, this class is intended to be
|
164 |
+
inherited from and __call__ redefined with an actual cost between the inputs.
|
165 |
+
"""
|
166 |
+
|
167 |
+
def __init__(self, params: CostParams) -> None:
|
168 |
+
super().__init__()
|
169 |
+
self._reduce_fun = params.reduce
|
170 |
+
self.scale = params.scale
|
171 |
+
|
172 |
+
reduce_fun_np_dict = {
|
173 |
+
"min": get_masked_reduce_np(np.min),
|
174 |
+
"max": get_masked_reduce_np(np.max),
|
175 |
+
"mean": get_masked_reduce_np(np.mean),
|
176 |
+
"discounted_mean": get_masked_reduce_np(
|
177 |
+
get_discounted_mean_np(params.discount_factor)
|
178 |
+
),
|
179 |
+
"now": get_masked_reduce_np(lambda *args, **kwargs: args[0][..., 0]),
|
180 |
+
"final": get_masked_reduce_np(lambda *args, **kwargs: args[0][..., -1]),
|
181 |
+
}
|
182 |
+
self._reduce_fun = reduce_fun_np_dict[params.reduce]
|
183 |
+
|
184 |
+
@property
|
185 |
+
def distance_bandwidth(self):
|
186 |
+
return 1
|
187 |
+
|
188 |
+
@property
|
189 |
+
def time_bandwidth(self):
|
190 |
+
return 1
|
191 |
+
|
192 |
+
def __call__(
|
193 |
+
self,
|
194 |
+
x1: np.ndarray,
|
195 |
+
x2: np.ndarray,
|
196 |
+
v1: np.ndarray,
|
197 |
+
v2: np.ndarray,
|
198 |
+
mask: Optional[np.ndarray] = None,
|
199 |
+
) -> Tuple[np.ndarray, Any]:
|
200 |
+
"""Compute the cost from given positions x1, x2 and velocities v1, v2
|
201 |
+
The base cost only returns 0 cost, use costs that inherit from this to compute an actual cost.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
x1 (some shape, num_steps, 2): positions of the first agent
|
205 |
+
x2 (some shape, num_steps, 2): positions of the second agent
|
206 |
+
v1 (some shape, num_steps, 2): velocities of the first agent
|
207 |
+
v2 (some shape, num_steps, 2): velocities of the second agent
|
208 |
+
mask (some_shape, num_steps, 2): mask set to True where the cost should be computed
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
(some_shape) cost for the compared states of agent 1 and agent 2, as well as any
|
212 |
+
supplementary cost-related information
|
213 |
+
"""
|
214 |
+
return (
|
215 |
+
self._reduce_fun(np.zeros_like(x2[..., 0]), mask, axis=-1),
|
216 |
+
None,
|
217 |
+
)
|
218 |
+
|
219 |
+
|
220 |
+
@dataclass
|
221 |
+
class DistanceCostParams(CostParams):
|
222 |
+
bandwidth: float
|
223 |
+
|
224 |
+
@staticmethod
|
225 |
+
def from_config(cfg: Config):
|
226 |
+
return DistanceCostParams(
|
227 |
+
scale=cfg.cost_scale,
|
228 |
+
reduce=cfg.cost_reduce,
|
229 |
+
bandwidth=cfg.distance_bandwidth,
|
230 |
+
discount_factor=cfg.discount_factor,
|
231 |
+
)
|
232 |
+
|
233 |
+
|
234 |
+
class DistanceCostTorch(BaseCostTorch):
|
235 |
+
def __init__(self, params: DistanceCostParams) -> None:
|
236 |
+
super().__init__(params)
|
237 |
+
self._bandwidth = params.bandwidth
|
238 |
+
|
239 |
+
@property
|
240 |
+
def distance_bandwidth(self):
|
241 |
+
return self._bandwidth
|
242 |
+
|
243 |
+
def __call__(
|
244 |
+
self, x1: Tensor, x2: Tensor, *args, mask: Optional[Tensor] = None, **kwargs
|
245 |
+
) -> Tuple[Tensor, Tensor]:
|
246 |
+
"""
|
247 |
+
Returns a cost estimation based on distance. Also returns distances between ego and pedestrians.
|
248 |
+
Args:
|
249 |
+
x1: First agent trajectory
|
250 |
+
x2: Second agent trajectory
|
251 |
+
mask: True where cost should be computed
|
252 |
+
Returns:
|
253 |
+
cost, distance_to_collision
|
254 |
+
"""
|
255 |
+
|
256 |
+
dist = torch.square(x2 - x1).sum(-1)
|
257 |
+
if mask is not None:
|
258 |
+
dist = torch.masked_fill(dist, torch.logical_not(mask), 1e9)
|
259 |
+
cost = torch.exp(-dist / (2 * self._bandwidth))
|
260 |
+
return self.scale * self._reduce_fun(cost, mask=mask, dim=-1), dist
|
261 |
+
|
262 |
+
|
263 |
+
class DistanceCostNumpy(BaseCostNumpy):
|
264 |
+
def __init__(self, params: DistanceCostParams) -> None:
|
265 |
+
super().__init__(params)
|
266 |
+
self._bandwidth = params.bandwidth
|
267 |
+
|
268 |
+
@property
|
269 |
+
def distance_bandwidth(self):
|
270 |
+
return self._bandwidth
|
271 |
+
|
272 |
+
def __call__(
|
273 |
+
self,
|
274 |
+
x1: np.ndarray,
|
275 |
+
x2: np.ndarray,
|
276 |
+
*args,
|
277 |
+
mask: Optional[np.ndarray] = None,
|
278 |
+
**kwargs
|
279 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
280 |
+
"""
|
281 |
+
Returns a cost estimation based on distance. Also returns distances between ego and pedestrians.
|
282 |
+
Args:
|
283 |
+
x1: First agent trajectory
|
284 |
+
x2: Second agent trajectory
|
285 |
+
mask: True where cost should be computed
|
286 |
+
Returns:
|
287 |
+
cost, distance_to_collision
|
288 |
+
"""
|
289 |
+
dist = np.square(x2 - x1).sum(-1)
|
290 |
+
if mask is not None:
|
291 |
+
dist = np.where(mask, dist, 1e9)
|
292 |
+
cost = np.exp(-dist / (2 * self._bandwidth))
|
293 |
+
return self.scale * self._reduce_fun(cost, mask=mask, axis=-1), dist
|
294 |
+
|
295 |
+
|
296 |
+
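As a quick illustration of the distance cost above, here is a minimal usage sketch with hypothetical parameter values (the CostParams fields scale, reduce, and discount_factor are assumed from the from_config constructors shown in this file; the reduce functions are assumed to accept mask=None):

import numpy as np

params = DistanceCostParams(
    scale=1.0, reduce="mean", discount_factor=0.95, bandwidth=2.0
)
cost_fn = DistanceCostNumpy(params)
t = np.linspace(0.0, 1.0, 5)
x1 = np.stack((t, np.zeros_like(t)), -1)  # agent 1 moves along the x axis
x2 = np.stack((t, np.ones_like(t)), -1)   # agent 2 runs parallel, 1 m away in y
cost, dist = cost_fn(x1, x2)
# dist holds the squared distance at every step (1.0 here), so
# cost = scale * mean(exp(-1.0 / (2 * 2.0))) ≈ 0.78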
@dataclass
class TTCCostParams(CostParams):
    distance_bandwidth: float
    time_bandwidth: float
    min_velocity_diff: float

    @staticmethod
    def from_config(cfg: Config):
        return TTCCostParams(
            scale=cfg.cost_scale,
            reduce=cfg.cost_reduce,
            distance_bandwidth=cfg.distance_bandwidth,
            time_bandwidth=cfg.time_bandwidth,
            min_velocity_diff=cfg.min_velocity_diff,
            discount_factor=cfg.discount_factor,
        )


class TTCCostTorch(BaseCostTorch):
    def __init__(self, params: TTCCostParams) -> None:
        super().__init__(params)
        self._d_bw = params.distance_bandwidth
        self._t_bw = params.time_bandwidth
        self._min_v = params.min_velocity_diff

    @property
    def distance_bandwidth(self):
        return self._d_bw

    @property
    def time_bandwidth(self):
        return self._t_bw

    def __call__(
        self,
        x1: Tensor,
        x2: Tensor,
        v1: Tensor,
        v2: Tensor,
        *args,
        mask: Optional[Tensor] = None,
        **kwargs
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """
        Returns a cost estimate based on time to collision and distance to collision,
        assuming constant relative velocity. Also returns the estimated time to collision
        and the distance between the agents at that time.

        Args:
            x1: (some_shape, sequence_length, feature_shape) Position of the first agent
            x2: (some_shape, sequence_length, feature_shape) Position of the second agent
            v1: (some_shape, sequence_length, feature_shape) Velocity of the first agent
            v2: (some_shape, sequence_length, feature_shape) Velocity of the second agent
            mask: (some_shape, sequence_length) True where the cost should be computed
        Returns:
            cost, (time_to_collision, distance_to_collision)
        """
        pos_diff = x1 - x2
        velocity_diff = v1 - v2

        dx = pos_diff[..., 0]
        dy = pos_diff[..., 1]
        vx = velocity_diff[..., 0]
        vy = velocity_diff[..., 1]

        # Squared norm of the relative velocity, clamped away from 0 for stability
        speed_diff = (
            torch.square(velocity_diff).sum(-1).clamp(self._min_v * self._min_v, None)
        )

        # Time at which the distance between the two agents is minimal
        TTC = -(dx * vx + dy * vy) / speed_diff

        # If the closest approach is in the past, use the current distance,
        # otherwise use the lateral distance at the time of closest approach.
        distance_TTC = torch.where(
            TTC < 0,
            torch.sqrt(dx * dx + dy * dy),
            torch.abs(vy * dx - vx * dy) / torch.sqrt(speed_diff),
        )
        TTC = torch.relu(TTC)
        if mask is not None:
            TTC = torch.masked_fill(TTC, torch.logical_not(mask), 1e9)
            distance_TTC = torch.masked_fill(distance_TTC, torch.logical_not(mask), 1e9)

        cost = self.scale * self._reduce_fun(
            torch.exp(
                -torch.square(TTC) / (2 * self._t_bw)
                - torch.square(distance_TTC) / (2 * self._d_bw)
            ),
            mask=mask,
            dim=-1,
        )

        return cost, (TTC, distance_TTC)


class TTCCostNumpy(BaseCostNumpy):
    def __init__(self, params: TTCCostParams) -> None:
        super().__init__(params)
        self._d_bw = params.distance_bandwidth
        self._t_bw = params.time_bandwidth
        self._min_v = params.min_velocity_diff

    @property
    def distance_bandwidth(self):
        return self._d_bw

    @property
    def time_bandwidth(self):
        return self._t_bw

    def __call__(
        self,
        x1: np.ndarray,
        x2: np.ndarray,
        v1: np.ndarray,
        v2: np.ndarray,
        *args,
        mask: Optional[np.ndarray] = None,
        **kwargs
    ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Returns a cost estimate based on time to collision and distance to collision,
        assuming constant relative velocity. Also returns the estimated time to collision
        and the distance between the agents at that time.

        Args:
            x1: (some_shape, sequence_length, feature_shape) Position of the first agent
            x2: (some_shape, sequence_length, feature_shape) Position of the second agent
            v1: (some_shape, sequence_length, feature_shape) Velocity of the first agent
            v2: (some_shape, sequence_length, feature_shape) Velocity of the second agent
            mask: (some_shape, sequence_length) True where the cost should be computed
        Returns:
            cost, (time_to_collision, distance_to_collision)
        """
        pos_diff = x1 - x2
        velocity_diff = v1 - v2

        dx = pos_diff[..., 0]
        dy = pos_diff[..., 1]
        vx = velocity_diff[..., 0]
        vy = velocity_diff[..., 1]

        speed_diff = np.maximum(
            np.square(velocity_diff).sum(-1), self._min_v * self._min_v
        )

        TTC = -(dx * vx + dy * vy) / speed_diff
        distance_TTC = np.where(
            TTC < 0,
            np.sqrt(dx * dx + dy * dy),
            np.abs(vy * dx - vx * dy) / np.sqrt(speed_diff),
        )
        TTC = np.maximum(TTC, 0.0)
        if mask is not None:
            TTC = np.where(mask, TTC, 1e9)
            distance_TTC = np.where(mask, distance_TTC, 1e9)

        cost = self.scale * self._reduce_fun(
            np.exp(
                -np.square(TTC) / (2 * self._t_bw)
                - np.square(distance_TTC) / (2 * self._d_bw)
            ),
            mask=mask,
            axis=-1,
        )
        return cost, (TTC, distance_TTC)

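To make the TTC formulas above concrete, consider a head-on sketch with hypothetical numbers: relative position (dx, dy) = (10, 0) m and relative velocity (vx, vy) = (-2, 0) m/s.

# Worked example of the closest-approach computation used by the TTC costs.
dx, dy, vx, vy = 10.0, 0.0, -2.0, 0.0
speed_diff = vx**2 + vy**2                        # 4.0
ttc = -(dx * vx + dy * vy) / speed_diff           # 5.0 s until closest approach
miss = abs(vy * dx - vx * dy) / speed_diff**0.5   # 0.0 m lateral miss distance

A small time to collision combined with a small miss distance drives the Gaussian cost term toward its maximum of scale.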
def compute_v_from_x(x: Tensor, y: Tensor, dt: float):
    """
    Computes velocities by finite differences of the positions.

    Args:
        x: (some_shape, past_time_sequence, features) Past positions of the agents
        y: (some_shape, future_time_sequence, features) Future positions of the agents
        dt: Duration of one time step
    Returns:
        v: (some_shape, future_time_sequence, features) Velocities of the agents
    """
    v = (y[..., 1:, :2] - y[..., :-1, :2]) / dt
    v_0 = (y[..., 0:1, :2] - x[..., -1:, :2]) / dt
    v = torch.cat((v_0, v), -2)
    return v

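A tiny sketch of compute_v_from_x with hypothetical values; the first future velocity is taken between the last observed and first predicted positions:

import torch

dt = 0.1
x = torch.tensor([[0.0, 0.0], [0.1, 0.0]])              # past positions
y = torch.tensor([[0.2, 0.0], [0.3, 0.0], [0.4, 0.0]])  # future positions
v = compute_v_from_x(x, y, dt)  # every row is (1.0, 0.0): 0.1 m per 0.1 s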
def get_cost(
    cost_function: BaseCostTorch,
    x: torch.Tensor,
    y_samples: torch.Tensor,
    offset: torch.Tensor,
    x_ego: torch.Tensor,
    y_ego: torch.Tensor,
    dt: float,
    unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute cost samples from predicted future trajectories

    Args:
        cost_function: Cost function to use
        x: (batch_size, n_agents, num_steps, state_dim) normalized tensor of history
        y_samples: (batch_size, n_agents, n_samples, num_steps_future, state_dim) normalized
            tensor of predicted future trajectory samples
        offset: (batch_size, n_agents, state_dim) offset position from ego
        x_ego: (batch_size, 1, num_steps, state_dim) tensor of ego history
        y_ego: (batch_size, 1, num_steps_future, state_dim) tensor of ego future trajectory
        dt: time step in trajectories
        unnormalizer: function that takes in a trajectory and an offset and outputs the
            unnormalized trajectory
        mask: tensor indicating where to compute the cost
    Returns:
        torch.Tensor: (batch_size, n_agents, n_samples) cost tensor
    """
    x = unnormalizer(x, offset)
    y_samples = unnormalizer(y_samples, offset)
    if offset.shape[1] > 1:
        x_ego = unnormalizer(x_ego, offset[:, 0:1])
        y_ego = unnormalizer(y_ego, offset[:, 0:1])

    min_dim = min(x.shape[-1], y_samples.shape[-1])
    x = x[..., :min_dim]
    y_samples = y_samples[..., :min_dim]
    x_ego = x_ego[..., :min_dim]
    y_ego = y_ego[..., :min_dim]
    assert x_ego.ndim == y_ego.ndim
    # If velocities are not part of the state, estimate them by finite differences
    if y_samples.shape[-1] < 5:
        v_samples = compute_v_from_x(x.unsqueeze(-3), y_samples, dt)
    else:
        v_samples = y_samples[..., 3:5]

    if y_ego.shape[-1] < 5:
        v_ego = compute_v_from_x(x_ego, y_ego, dt)
    else:
        v_ego = y_ego[..., 3:5]
    if mask is not None:
        # A step is valid only if both it and its predecessor are valid
        # (velocities come from finite differences)
        mask = torch.cat(
            (mask[..., 0:1], torch.logical_and(mask[..., 1:], mask[..., :-1])), -1
        )

    cost, _ = cost_function(
        x1=y_ego.unsqueeze(-3),
        x2=y_samples,
        v1=v_ego.unsqueeze(-3),
        v2=v_samples,
        mask=mask,
    )
    return cost
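A shape-level sketch of get_cost (hypothetical shapes and parameter values; an identity unnormalizer stands in for the dataloader's real one):

import torch

batch, agents, samples, steps, steps_future = 2, 3, 4, 5, 6
ttc_cost = TTCCostTorch(
    TTCCostParams(
        scale=1.0,
        reduce="mean",
        discount_factor=0.95,
        distance_bandwidth=2.0,
        time_bandwidth=1.0,
        min_velocity_diff=0.1,
    )
)
cost = get_cost(
    cost_function=ttc_cost,
    x=torch.randn(batch, agents, steps, 2),
    y_samples=torch.randn(batch, agents, samples, steps_future, 2),
    offset=torch.zeros(batch, agents, 2),
    x_ego=torch.randn(batch, 1, steps, 2),
    y_ego=torch.randn(batch, 1, steps_future, 2),
    dt=0.1,
    unnormalizer=lambda traj, offset: traj,
)
# cost: (batch, agents, samples)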
risk_biased/utils/load_model.py
ADDED
@@ -0,0 +1,220 @@
from typing import Callable, Optional, Tuple, Union
import os
import warnings

from mmcv import Config
import torch
import wandb

from risk_biased.predictors.biased_predictor import (
    LitTrajectoryPredictor,
    LitTrajectoryPredictorParams,
)

from risk_biased.utils.config_argparse import config_argparse
from risk_biased.utils.cost import TTCCostParams
from risk_biased.utils.torch_utils import load_weights

from risk_biased.scene_dataset.loaders import SceneDataLoaders
from risk_biased.scene_dataset.scene import load_create_dataset

from risk_biased.utils.waymo_dataloader import WaymoDataloaders


def get_predictor(
    config: Config, unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
):
    params = LitTrajectoryPredictorParams.from_config(config)
    model_class = LitTrajectoryPredictor
    ttc_params = TTCCostParams.from_config(config)
    return model_class(params=params, unnormalizer=unnormalizer, cost_params=ttc_params)


def load_from_wandb_id(
    log_id: str,
    log_path: str,
    entity: str,
    project: str,
    config: Optional[Config] = None,
    load_last=False,
) -> Tuple[LitTrajectoryPredictor, Union[WaymoDataloaders, SceneDataLoaders], Config]:
    """
    Load a model using a wandb id code.
    Args:
        log_id: the wandb id code
        log_path: the wandb log directory path
        entity: the wandb entity that owns the project
        project: the wandb project name
        config: An optional configuration argument; use these settings if not None, use the settings from the log directory otherwise
        load_last: An optional argument, set to True to load the last checkpoint instead of the best one
    Returns:
        Predictor model, dataloaders, and the config either loaded from the checkpoint or the one passed as argument.
    """
    list_matching = list(filter(lambda path: log_id in path, os.listdir(log_path)))
    if len(list_matching) == 1:
        list_ckpt = list(
            filter(
                lambda path: "epoch" in path and ".ckpt" in path,
                os.listdir(os.path.join(log_path, list_matching[0], "files")),
            )
        )
        if not load_last and len(list_ckpt) == 1:
            print(f"Loading best model: {list_ckpt[0]}.")
            checkpoint_path = os.path.join(
                log_path, list_matching[0], "files", list_ckpt[0]
            )
        else:
            print("Loading last checkpoint.")
            checkpoint_path = os.path.join(
                log_path, list_matching[0], "files", "last.ckpt"
            )
        config_path = os.path.join(
            log_path, list_matching[0], "files", "learning_config.py"
        )

        if config is None:
            config = config_argparse(config_path)
            distant_model_type = None
        else:
            distant_config = config_argparse(config_path)
            distant_model_type = distant_config.model_type
            config["load_from"] = log_id

        if config.model_type == "interaction_biased":
            dataloaders = WaymoDataloaders(config)
        else:
            [data_train, data_val, data_test] = load_create_dataset(config)
            dataloaders = SceneDataLoaders(
                state_dim=config.state_dim,
                num_steps=config.num_steps,
                num_steps_future=config.num_steps_future,
                batch_size=config.batch_size,
                data_train=data_train,
                data_val=data_val,
                data_test=data_test,
                num_workers=config.num_workers,
            )

        try:
            # Load the checkpoint onto GPU if some are configured, onto CPU otherwise
            map_location = "cuda" if len(config.gpus) else "cpu"
            model = load_weights(
                get_predictor(config, dataloaders.unnormalize_trajectory),
                torch.load(checkpoint_path, map_location=map_location),
                strict=True,
            )
        except RuntimeError:
            raise RuntimeError(
                f"The source model is of type {distant_model_type}."
                + " It cannot be used to load the weights of the interaction biased model."
            )

        return model, dataloaders, config

    else:
        print("Trying to download logs from WandB...")
        api = wandb.Api()
        run = api.run(entity + "/" + project + "/" + log_id)
        if run is not None:
            checkpoint_path = os.path.join(
                log_path, "downloaded_run-" + log_id, "files"
            )
            os.makedirs(checkpoint_path)
            for file in run.files():
                if file.name.endswith("ckpt") or file.name.endswith("config.py"):
                    file.download(checkpoint_path)
            return load_from_wandb_id(
                log_id, log_path, entity, project, config, load_last
            )
        else:
            raise RuntimeError(
                f"Error while loading checkpoint: Found {len(list_matching)} occurrences of the given id {log_id} in the logs at {log_path}."
            )


def load_from_config(cfg: Config):
    """
    Loads the predictor model and the data depending on which one is selected in the config.
    If the "load_from" field is not empty, tries to load the pre-trained model from the
    checkpoint; the matching config file is loaded as well.

    Args:
        cfg: Configuration that defines the model to be loaded

    Returns:
        The loaded model, the dataloaders, and a new version of the config that is
        compatible with the checkpoint the model could be loaded from
    """

    log_path = os.path.join(cfg.log_path, "wandb")
    ignored_keys = [
        "project",
        "dataset_parameters",
        "load_from",
        "force_config",
        "load_last",
    ]

    if "load_from" in cfg.keys() and cfg.load_from != "" and cfg.load_from:
        if "load_last" in cfg.keys():
            load_last = cfg["load_last"]
        else:
            load_last = False
        if cfg.force_config:
            warnings.warn(
                f"Using local configuration but loading from run {cfg.load_from}. Will fail if local configuration is not compatible."
            )
            predictor, dataloaders, config = load_from_wandb_id(
                log_id=cfg.load_from,
                log_path=log_path,
                entity=cfg.entity,
                project=cfg.project,
                config=cfg,
                load_last=load_last,
            )
        else:
            predictor, dataloaders, config = load_from_wandb_id(
                log_id=cfg.load_from,
                log_path=log_path,
                entity=cfg.entity,
                project=cfg.project,
                load_last=load_last,
            )
        difference = False
        warning_message = ""
        for key, item in cfg.items():
            try:
                if config[key] != item:
                    if not difference:
                        warning_message += "When loading the model, the configuration was changed to match the configuration of the pre-trained model to be loaded.\n"
                        difference = True
                    if key not in ignored_keys:
                        warning_message += f" The value of '{key}' is now '{config[key]}' instead of '{item}'."
            except KeyError:
                if not difference:
                    warning_message += "When loading the model, the configuration was changed to match the configuration of the pre-trained model to be loaded."
                    difference = True
                warning_message += f" The parameter '{key}' with value '{item}' does not exist for the model you are loading from, it is added."
                config[key] = item
        if warning_message != "":
            warnings.warn(warning_message)
        return predictor, dataloaders, config

    else:
        if cfg.model_type == "interaction_biased":
            dataloaders = WaymoDataloaders(cfg)
        else:
            [data_train, data_val, data_test] = load_create_dataset(cfg)
            dataloaders = SceneDataLoaders(
                state_dim=cfg.state_dim,
                num_steps=cfg.num_steps,
                num_steps_future=cfg.num_steps_future,
                batch_size=cfg.batch_size,
                data_train=data_train,
                data_val=data_val,
                data_test=data_test,
                num_workers=cfg.num_workers,
            )

        predictor = get_predictor(cfg, dataloaders.unnormalize_trajectory)
        return predictor, dataloaders, cfg
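A typical call site, as a hedged sketch (the config path follows this repository's layout; config_argparse is assumed to accept the config file path as shown in load_from_wandb_id above):

from risk_biased.utils.config_argparse import config_argparse
from risk_biased.utils.load_model import load_from_config

cfg = config_argparse("risk_biased/config/learning_config.py")
predictor, dataloaders, cfg = load_from_config(cfg)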
risk_biased/utils/loss.py
ADDED
@@ -0,0 +1,124 @@
from typing import Optional

import torch


def reconstruction_loss(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """Masked reconstruction loss: average Euclidean distance between predicted and true
    positions, with extra angle and velocity terms when those features are available.

    Args:
        pred (Tensor): (..., time, [x, y, (a), (vx, vy)])
        truth (Tensor): (..., time, [x, y, (a), (vx, vy)])
        mask_loss (Tensor): (..., time) Defaults to None.
    """
    min_feat_shape = min(pred.shape[-1], truth.shape[-1])
    if min_feat_shape == 3:
        assert pred.shape[-1] == truth.shape[-1]
        # Position term plus an angle term compared on the unit circle
        return reconstruction_loss(
            pred[..., :2], truth[..., :2], mask_loss
        ) + reconstruction_loss(
            torch.stack([torch.cos(pred[..., 2]), torch.sin(pred[..., 2])], -1),
            torch.stack([torch.cos(truth[..., 2]), torch.sin(truth[..., 2])], -1),
            mask_loss,
        )
    elif min_feat_shape >= 5:
        assert pred.shape[-1] <= truth.shape[-1]
        # Only compare angles where the agent is actually moving
        v_norm = torch.sum(torch.square(truth[..., 3:5]), -1, keepdim=True)
        v_mask = v_norm > 1
        return (
            reconstruction_loss(pred[..., :2], truth[..., :2], mask_loss)
            + reconstruction_loss(
                torch.stack([torch.cos(pred[..., 2]), torch.sin(pred[..., 2])], -1)
                * v_mask,
                torch.stack([torch.cos(truth[..., 2]), torch.sin(truth[..., 2])], -1)
                * v_mask,
                mask_loss,
            )
            + reconstruction_loss(pred[..., 3:5], truth[..., 3:5], mask_loss)
        )
    elif min_feat_shape == 2:
        if mask_loss is None:
            return torch.mean(
                torch.sqrt(
                    torch.sum(
                        torch.square(pred[..., :2] - truth[..., :2]), -1
                    ).clamp_min(1e-6)
                )
            )
        else:
            assert mask_loss.any()
            mask_loss = mask_loss.float()
            return torch.sum(
                torch.sqrt(
                    torch.sum(
                        torch.square(pred[..., :2] - truth[..., :2]), -1
                    ).clamp_min(1e-6)
                )
                * mask_loss
            ) / torch.sum(mask_loss).clamp_min(1)


def map_penalized_reconstruction_loss(
    pred: torch.Tensor,
    truth: torch.Tensor,
    map: torch.Tensor,
    mask_map: torch.Tensor,
    mask_loss: Optional[torch.Tensor] = None,
    map_importance: float = 0.1,
):
    """Reconstruction loss with a penalty on the final predicted position being far from
    the map objects.

    Args:
        pred (Tensor): (batch_size, num_agents, time, [x, y, (a), (vx, vy)])
        truth (Tensor): (batch_size, num_agents, time, [x, y, (a), (vx, vy)])
        map (Tensor): (batch_size, num_objects, object_sequence_length, [x, y, ...])
        mask_map (Tensor): (batch_size, num_objects, object_sequence_length)
        mask_loss (Tensor): (..., time) Defaults to None.
    """
    # Squared distance from the final predicted position to the closest map object
    # (broadcasting b, 1, o, s, f against b, a, 1, 1, f, then taking min over objects)
    map_distance, _ = (
        (map[:, None, :, :, :2] - pred[:, :, None, -1, None, :2])
        .square()
        .sum(-1)
        .min(2)
    )
    map_distance = map_distance.sqrt().clamp(0.5, 3)
    if mask_loss is not None:
        map_loss = (map_distance * mask_loss[..., -1:]).sum() / mask_loss[..., -1].sum()
    else:
        map_loss = map_distance.mean()

    rec_loss = reconstruction_loss(pred, truth, mask_loss)

    return rec_loss + map_importance * map_loss


def cce_loss_with_logits(pred_logits: torch.Tensor, truth: torch.Tensor):
    """Categorical cross-entropy between predicted logits and a target distribution."""
    pred_log = pred_logits.log_softmax(-1)
    return -(pred_log * truth).sum(-1).mean()


def risk_loss_function(
    pred: torch.Tensor,
    truth: torch.Tensor,
    mask: torch.Tensor,
    factor: float = 100.0,
) -> torch.Tensor:
    """
    Loss function for the risk comparison. This is asymmetric because it is preferred that
    the model over-estimates the risk rather than under-estimates it: large positive errors
    are only penalized logarithmically.

    Args:
        pred: (same_shape) The predicted risks
        truth: (same_shape) The reference risks to match
        mask: (same_shape) A mask with 1 where the loss should be computed and 0 elsewhere.
        factor: Scaling applied to the error before the asymmetric penalty; the larger this
            value, the sooner over-estimation switches to the logarithmic regime.
    Returns:
        Scalar loss value
    """
    error = pred - truth
    error = error * factor
    error = torch.where(error > 1, (error + 1e-6).log(), error.abs())
    error = (error * mask).sum() / mask.sum()
    return error
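A minimal sketch of the masked position-only branch of reconstruction_loss (hypothetical values):

import torch

pred = torch.zeros(1, 4, 2)
truth = torch.ones(1, 4, 2)                      # every step is sqrt(2) m away
mask = torch.tensor([[True, True, True, False]])
loss = reconstruction_loss(pred, truth, mask)    # ≈ sqrt(2), averaged over 3 valid steps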
risk_biased/utils/metrics.py
ADDED
@@ -0,0 +1,81 @@
from typing import Optional

import torch


def FDE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """Final displacement error: distance between the predicted and true final positions.

    Args:
        pred (Tensor): (..., time, xy)
        truth (Tensor): (..., time, xy)
        mask_loss (Tensor): (..., time) Defaults to None.
    """
    if mask_loss is None:
        return torch.mean(
            torch.sqrt(
                torch.sum(torch.square(pred[..., -1, :2] - truth[..., -1, :2]), -1)
            )
        )
    else:
        mask_loss = mask_loss.float()
        return torch.sum(
            torch.sqrt(
                torch.sum(torch.square(pred[..., -1, :2] - truth[..., -1, :2]), -1)
            )
            * mask_loss[..., -1]
        ) / torch.sum(mask_loss[..., -1]).clamp_min(1)


def ADE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """Average displacement error: mean distance between predicted and true positions
    over all time steps.

    Args:
        pred (Tensor): (..., time, xy)
        truth (Tensor): (..., time, xy)
        mask_loss (Tensor): (..., time) Defaults to None.
    """
    if mask_loss is None:
        return torch.mean(
            torch.sqrt(
                torch.sum(torch.square(pred[..., :, :2] - truth[..., :, :2]), -1)
            )
        )
    else:
        mask_loss = mask_loss.float()
        return torch.sum(
            torch.sqrt(
                torch.sum(torch.square(pred[..., :, :2] - truth[..., :, :2]), -1)
            )
            * mask_loss
        ) / torch.sum(mask_loss).clamp_min(1)


def minFDE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """Minimum final displacement error over a set of predicted samples.

    Args:
        pred (Tensor): (..., n_samples, time, xy)
        truth (Tensor): (..., time, xy)
        mask_loss (Tensor): (..., time) Defaults to None.
    """
    if mask_loss is None:
        min_distances, _ = torch.min(
            torch.sqrt(
                torch.sum(torch.square(pred[..., -1, :2] - truth[..., -1, :2]), -1)
            ),
            -1,
        )
        return torch.mean(min_distances)
    else:
        mask_loss = mask_loss[..., -1].float()
        final_distances = torch.sqrt(
            torch.sum(torch.square(pred[..., -1, :2] - truth[..., -1, :2]), -1)
        )
        # Push masked-out samples to the maximum distance so they cannot be selected
        max_final_distance = torch.max(final_distances * mask_loss)
        min_distances, _ = torch.min(
            final_distances + max_final_distance * (1 - mask_loss), -1
        )
        return torch.sum(min_distances * mask_loss.any(-1)) / torch.sum(
            mask_loss.any(-1)
        ).clamp_min(1)
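A small sketch contrasting the two metrics (hypothetical values): ADE averages the displacement over all steps while FDE only looks at the last one:

import torch

truth = torch.zeros(1, 3, 2)
pred = torch.tensor([[[0.0, 0.0], [0.0, 1.0], [0.0, 2.0]]])
ade = ADE(pred, truth)  # (0 + 1 + 2) / 3 = 1.0
fde = FDE(pred, truth)  # 2.0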
risk_biased/utils/planner_utils.py
ADDED
@@ -0,0 +1,462 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch

from risk_biased.mpc_planner.planner_cost import TrackingCost
from risk_biased.utils.cost import BaseCostTorch
from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator


def get_rotation_matrix(angle, device):
    """Return the 2D rotation matrix (or batch of matrices) for the given angle."""
    c = torch.cos(angle)
    s = torch.sin(angle)
    rot_matrix = torch.stack(
        (torch.stack((c, s), -1), torch.stack((-s, c), -1)), -1
    ).to(device)
    return rot_matrix


class AbstractState(ABC):
    """
    State representation using an underlying tensor. Position, velocity, and angle can be accessed.
    """

    @property
    @abstractmethod
    def position(self) -> torch.Tensor:
        """Extract position information from the state tensor

        Returns:
            position_tensor of size (..., 2)
        """

    @property
    @abstractmethod
    def velocity(self) -> torch.Tensor:
        """Extract velocity information from the state tensor

        Returns:
            velocity_tensor of size (..., 2)
        """

    @property
    @abstractmethod
    def angle(self) -> torch.Tensor:
        """Extract heading angle information from the state tensor

        Returns:
            angle_tensor of size (..., 1)
        """

    @abstractmethod
    def get_states(self, dim: int) -> torch.Tensor:
        """Return the underlying states tensor with dim 2, 4 or 5 ([x, y], [x, y, vx, vy], or [x, y, angle, vx, vy])."""

    @abstractmethod
    def rotate(self, angle: float, in_place: bool) -> AbstractState:
        """Rotate the state by the given angle
        Args:
            angle: in radians
            in_place: whether to change the object itself or return a rotated copy
        Returns:
            rotated self or rotated copy of self
        """

    @abstractmethod
    def translate(self, translation: torch.Tensor, in_place: bool) -> AbstractState:
        """Translate the state by the given translation
        Args:
            translation: translation vector in 2 dimensions
            in_place: whether to change the object itself or return a translated copy
        """

    # Overloaded operators so that the state behaves as a tensor for some operations
    def __getitem__(self, key) -> AbstractState:
        """
        Use get item on the underlying tensor to get the item at the given key.
        Always returns a velocity state so that, if the underlying time sequence is reduced to one step, the velocity is still accessible.
        """
        if isinstance(key, int):
            key = (key, Ellipsis, slice(None, None, None))
        elif Ellipsis not in key:
            key = (*key, Ellipsis, slice(None, None, None))
        else:
            key = (*key, slice(None, None, None))

        return to_state(
            torch.cat(
                (
                    self.position[key],
                    self.velocity[key],
                ),
                dim=-1,
            ),
            self.dt,
        )

    @property
    def shape(self):
        return self._states.shape[:-1]


def to_state(in_tensor: torch.Tensor, dt: float) -> AbstractState:
    """Wrap a raw tensor into the state class matching its last dimension."""
    if in_tensor.shape[-1] == 2:
        return PositionSequenceState(in_tensor, dt)
    elif in_tensor.shape[-1] == 4:
        return PositionVelocityState(in_tensor, dt)
    else:
        assert in_tensor.shape[-1] > 4
        return PositionAngleVelocityState(in_tensor, dt)

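The dispatch in to_state is driven purely by the last tensor dimension; a short sketch with hypothetical tensors:

import torch

dt = 0.1
s2 = to_state(torch.zeros(8, 2), dt)  # PositionSequenceState: positions only
s4 = to_state(torch.zeros(8, 4), dt)  # PositionVelocityState: [x, y, vx, vy]
s5 = to_state(torch.zeros(8, 5), dt)  # PositionAngleVelocityState: [x, y, angle, vx, vy]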
class PositionSequenceState(AbstractState):
    """
    State representation with an underlying tensor defining only positions.
    """

    def __init__(self, states: torch.Tensor, dt: float) -> None:
        super().__init__()
        # Check that the input tensor defines only the position
        assert states.shape[-1] == 2
        # Check that the input tensor defines a sequence of positions
        # (otherwise the velocity cannot be computed)
        assert states.ndim > 1 and states.shape[-2] > 1
        self.dt = dt
        self._states = states.clone()

    @property
    def position(self) -> torch.Tensor:
        return self._states

    @property
    def velocity(self) -> torch.Tensor:
        vel = (self._states[..., 1:, :] - self._states[..., :-1, :]) / self.dt
        vel = torch.cat((vel[..., 0:1, :], vel), dim=-2)
        return vel.clone()

    @property
    def angle(self) -> torch.Tensor:
        vel = self.velocity
        angle = torch.arctan2(vel[..., 1:2], vel[..., 0:1])
        return angle

    def get_states(self, dim: int = 2) -> torch.Tensor:
        if dim == 2:
            return self._states.clone()
        elif dim == 4:
            return torch.cat((self._states.clone(), self.velocity), dim=-1)
        elif dim == 5:
            return torch.cat((self._states.clone(), self.angle, self.velocity), dim=-1)
        else:
            raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")

    def rotate(self, angle: float, in_place: bool = False) -> PositionSequenceState:
        """Rotate the state by the given angle in radians"""
        rot_matrix = get_rotation_matrix(angle, self._states.device)
        if in_place:
            self._states = (rot_matrix @ self._states.unsqueeze(-1)).squeeze(-1)
            return self
        else:
            return to_state(
                (rot_matrix @ self._states.unsqueeze(-1).clone()).squeeze(-1), self.dt
            )

    def translate(
        self, translation: torch.Tensor, in_place: bool = False
    ) -> PositionSequenceState:
        """Translate the state by the given translation"""
        if in_place:
            self._states[..., :2] += translation.expand_as(self._states[..., :2])
            return self
        else:
            return to_state(
                self._states[..., :2].clone()
                + translation.expand_as(self._states[..., :2]),
                self.dt,
            )


class PositionVelocityState(AbstractState):
    """
    State representation with an underlying tensor defining position and velocity.
    """

    def __init__(self, states: torch.Tensor, dt) -> None:
        super().__init__()
        assert states.shape[-1] == 4
        self._states = states.clone()
        self.dt = dt

    @property
    def position(self) -> torch.Tensor:
        return self._states[..., :2]

    @property
    def velocity(self) -> torch.Tensor:
        return self._states[..., 2:4]

    @property
    def angle(self) -> torch.Tensor:
        vel = self.velocity
        angle = torch.arctan2(vel[..., 1:2], vel[..., 0:1])
        return angle

    def get_states(self, dim: int = 4) -> torch.Tensor:
        if dim == 2:
            return self._states[..., :2].clone()
        elif dim == 4:
            return self._states.clone()
        elif dim == 5:
            return torch.cat(
                (
                    self._states[..., :2].clone(),
                    self.angle,
                    self._states[..., 2:].clone(),
                ),
                dim=-1,
            )
        else:
            raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")

    def rotate(
        self, angle: torch.Tensor, in_place: bool = False
    ) -> PositionVelocityState:
        """Rotate the state by the given angle in radians"""
        rot_matrix = get_rotation_matrix(angle, self._states.device)
        rotated_pos = (rot_matrix @ self.position.unsqueeze(-1)).squeeze(-1)
        rotated_vel = (rot_matrix @ self.velocity.unsqueeze(-1)).squeeze(-1)
        if in_place:
            self._states = torch.cat((rotated_pos, rotated_vel), dim=-1)
            return self
        else:
            return to_state(torch.cat((rotated_pos, rotated_vel), dim=-1), self.dt)

    def translate(
        self, translation: torch.Tensor, in_place: bool = False
    ) -> PositionVelocityState:
        """Translate the state by the given translation"""
        if in_place:
            self._states[..., :2] += translation.expand_as(self._states[..., :2])
            return self
        else:
            return to_state(
                torch.cat(
                    (
                        self._states[..., :2].clone()
                        + translation.expand_as(self._states[..., :2]),
                        self._states[..., 2:].clone(),
                    ),
                    dim=-1,
                ),
                self.dt,
            )


class PositionAngleVelocityState(AbstractState):
    """
    State representation with an underlying tensor defining position, angle, and velocity.
    """

    def __init__(self, states: torch.Tensor, dt: float) -> None:
        super().__init__()
        assert states.shape[-1] == 5
        self._states = states.clone()
        self.dt = dt

    @property
    def position(self) -> torch.Tensor:
        return self._states[..., :2].clone()

    @property
    def velocity(self) -> torch.Tensor:
        return self._states[..., 3:5].clone()

    @property
    def angle(self) -> torch.Tensor:
        return self._states[..., 2:3].clone()

    def get_states(self, dim: int = 5) -> torch.Tensor:
        if dim == 2:
            return self._states[..., :2].clone()
        elif dim == 4:
            return torch.cat(
                (self._states[..., :2].clone(), self._states[..., 3:].clone()), dim=-1
            )
        elif dim == 5:
            return self._states.clone()
        else:
            raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")

    def rotate(
        self, angle: float, in_place: bool = False
    ) -> PositionAngleVelocityState:
        """Rotate the state by the given angle in radians"""
        rot_matrix = get_rotation_matrix(angle, self._states.device)
        rotated_pos = (rot_matrix @ self.position.unsqueeze(-1)).squeeze(-1)
        rotated_angle = self.angle + angle
        rotated_vel = (rot_matrix @ self.velocity.unsqueeze(-1)).squeeze(-1)
        if in_place:
            self._states = torch.cat((rotated_pos, rotated_angle, rotated_vel), -1)
            return self
        else:
            return to_state(
                torch.cat((rotated_pos, rotated_angle, rotated_vel), -1), self.dt
            )

    def translate(
        self, translation: torch.Tensor, in_place: bool = False
    ) -> PositionAngleVelocityState:
        """Translate the state by the given translation"""
        if in_place:
            self._states[..., :2] += translation.expand_as(self._states[..., :2])
            return self
        else:
            return to_state(
                torch.cat(
                    (
                        self._states[..., :2]
                        + translation.expand_as(self._states[..., :2]),
                        self._states[..., 2:],
                    ),
                    dim=-1,
                ),
                self.dt,
            )

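A sanity-check sketch of the rotation convention (hypothetical values): rotating a state heading along +x by pi/2 turns both its position and velocity toward +y:

import math
import torch

state = PositionVelocityState(torch.tensor([[1.0, 0.0, 1.0, 0.0]]), dt=0.1)
rotated = state.rotate(torch.tensor(math.pi / 2))
# rotated.position ≈ [[0., 1.]] and rotated.velocity ≈ [[0., 1.]]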
def get_interaction_cost(
    ego_state_future: AbstractState,
    ado_state_future_samples: AbstractState,
    interaction_cost_function: BaseCostTorch,
) -> torch.Tensor:
    """Computes interaction cost samples from predicted ado future trajectories and a batch
    of ego future trajectories

    Args:
        ego_state_future: ((num_control_samples), num_agents, num_steps_future) ego state
            future trajectory
        ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
            predicted ado state trajectory samples
        interaction_cost_function: interaction cost function between ego and (stochastic) ado

    Returns:
        (num_control_samples, num_agents, num_prediction_samples) interaction cost tensor
    """
    if len(ego_state_future.shape) == 2:
        x_ego = ego_state_future.position.unsqueeze(0)
        v_ego = ego_state_future.velocity.unsqueeze(0)
    else:
        x_ego = ego_state_future.position
        v_ego = ego_state_future.velocity

    num_control_samples = ego_state_future.shape[0]
    ado_position_future_samples = ado_state_future_samples.position.unsqueeze(0).expand(
        num_control_samples, -1, -1, -1, -1
    )

    v_samples = ado_state_future_samples.velocity.unsqueeze(0).expand(
        num_control_samples, -1, -1, -1, -1
    )

    interaction_cost, _ = interaction_cost_function(
        x1=x_ego.unsqueeze(1),
        x2=ado_position_future_samples,
        v1=v_ego.unsqueeze(1),
        v2=v_samples,
    )
    return interaction_cost.permute(0, 2, 1)


def evaluate_risk(
    risk_level: float,
    cost: torch.Tensor,
    weights: torch.Tensor,
    risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
) -> torch.Tensor:
    """Returns a risk tensor given costs and optionally a risk level

    Args:
        risk_level: a risk-level float. If 0.0, the risk-neutral expectation is returned.
        cost: (num_control_samples, num_agents, num_prediction_samples) cost tensor
        weights: (num_control_samples, num_agents, num_prediction_samples) probability weights of the cost tensor
        risk_estimator (optional): a Monte Carlo risk estimator. Defaults to None.

    Returns:
        (num_control_samples, num_agents) risk tensor
    """
    num_control_samples, num_agents, _ = cost.shape

    if risk_level == 0.0:
        risk = cost.mean(dim=-1)
    else:
        assert risk_estimator is not None, "no risk estimator is specified"
        risk = risk_estimator(
            risk_level * torch.ones(num_control_samples, num_agents),
            cost,
            weights=weights,
        )
    return risk


def evaluate_control_sequence(
    control_sequence: torch.Tensor,
    dynamics_model,
    ego_state_history: AbstractState,
    ego_state_target_trajectory: AbstractState,
    ado_state_future_samples: AbstractState,
    sample_weights: torch.Tensor,
    interaction_cost_function: BaseCostTorch,
    tracking_cost_function: TrackingCost,
    risk_level: float = 0.0,
    risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
) -> Tuple[float, float]:
    """Returns the risk and tracking cost evaluation of the given control sequence

    Args:
        control_sequence: (num_steps_future, control_dim) tensor of control sequence
        dynamics_model: dynamics model for control
        ego_state_history: past ego states; the last one is the current state
        ego_state_target_trajectory: (num_steps_future) tensor of ego target state trajectory
        ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
            predicted ado trajectory sample states
        sample_weights: (num_prediction_samples, num_agents) tensor of probability weights of the samples
        interaction_cost_function: interaction cost function between ego and (stochastic) ado
        tracking_cost_function: deterministic tracking cost that does not involve ado
        risk_level (optional): a risk-level float. If 0.0, the risk-neutral expectation is
            used. Defaults to 0.0.
        risk_estimator (optional): a Monte Carlo risk estimator. Defaults to None.

    Returns:
        tuple of (interaction_risk, tracking_cost)
    """
    ego_state_current = ego_state_history[..., -1]
    ego_state_future = dynamics_model.simulate(ego_state_current, control_sequence)
    # state starts with x, y, angle, vx, vy
    tracking_cost = tracking_cost_function(
        ego_state_future.position,
        ego_state_target_trajectory.position,
        ego_state_target_trajectory.velocity,
    )

    interaction_cost = get_interaction_cost(
        ego_state_future,
        ado_state_future_samples,
        interaction_cost_function,
    )

    interaction_risk = evaluate_risk(
        risk_level,
        interaction_cost,
        sample_weights.permute(1, 0).unsqueeze(0).expand_as(interaction_cost),
        risk_estimator,
    )

    # TODO: averaging over agents here, but we might want to reduce in a different way
    return (interaction_risk.mean().item(), tracking_cost.mean().item())
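A hedged sketch of evaluate_risk with one control sample, one agent, and three hypothetical cost samples; a risk level of 0.0 returns the plain expectation while a CVaR estimator at 0.9 focuses on the expensive tail:

import torch

from risk_biased.utils.risk import CVaREstimator

cost = torch.tensor([[[1.0, 2.0, 10.0]]])
weights = torch.ones_like(cost) / 3
neutral = evaluate_risk(0.0, cost, weights)                                  # 13 / 3
averse = evaluate_risk(0.9, cost, weights, risk_estimator=CVaREstimator())  # 10.0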
risk_biased/utils/risk.py
ADDED
@@ -0,0 +1,571 @@
import inspect
import math
import warnings
from abc import ABC, abstractmethod

import torch
from torch import Tensor


class AbstractMonteCarloRiskEstimator(ABC):
    """Abstract class for Monte Carlo estimation of risk objectives"""

    @abstractmethod
    def __call__(self, risk_level: Tensor, cost: Tensor) -> Tensor:
        """Computes and returns the risk objective estimated on the cost tensor

        Args:
            risk_level: (batch_size,) tensor of risk-levels at which the risk objective is computed
            cost: (batch_size, n_samples) tensor of cost samples
        Returns:
            risk tensor of size (batch_size,)
        """


class EntropicRiskEstimator(AbstractMonteCarloRiskEstimator):
    """Monte Carlo estimator for the entropic risk objective.
    This estimator computes the entropic risk as
    1/risk_level * log(mean(exp(risk_level * cost), -1)).
    However, this expression is numerically unstable for small risk_level, so below the
    threshold eps the estimator uses the second-order Taylor expansion
    mean + 0.5 * risk_level * variance instead.

    Args:
        eps: Risk-level threshold below which the Taylor expansion is used. Defaults to 1e-4.
    """

    def __init__(self, eps: float = 1e-4) -> None:
        self.eps = eps

    def __call__(self, risk_level: Tensor, cost: Tensor, weights: Tensor) -> Tensor:
        """Computes and returns the entropic risk estimated on the cost tensor

        Args:
            risk_level: (batch_size, n_agents) tensor of risk-levels at which the risk objective is computed
            cost: (batch_size, n_agents, n_samples) cost tensor
            weights: (batch_size, n_agents, n_samples) tensor of weights for the cost samples
        Returns:
            entropic risk tensor of size (batch_size, n_agents)
        """
        weights = weights / weights.sum(dim=-1, keepdim=True)
        batch_size, n_agents, n_samples = cost.shape
        entropic_risk_cost_large_sigma = (
            ((risk_level.view(batch_size, n_agents, 1) * cost).exp() * weights)
            .sum(-1)
            .log()
        ) / risk_level

        # Weighted mean and variance for the second-order Taylor expansion
        mean = (cost * weights).sum(dim=-1)
        var = (cost**2 * weights).sum(dim=-1) - mean**2
        entropic_risk_cost_small_sigma = mean + 0.5 * risk_level * var

        return torch.where(
            torch.abs(risk_level) > self.eps,
            entropic_risk_cost_large_sigma,
            entropic_risk_cost_small_sigma,
        )

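A numeric sanity check of the estimator (hypothetical values): for a small risk level the entropic risk stays close to its Taylor expansion mean + 0.5 * risk_level * variance:

import torch

est = EntropicRiskEstimator()
cost = torch.tensor([[[0.0, 2.0]]])                  # weighted mean 1, variance 1
weights = torch.ones_like(cost) / 2
risk = est(torch.full((1, 1), 0.01), cost, weights)  # ≈ 1 + 0.5 * 0.01 * 1 = 1.005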
class CVaREstimator(AbstractMonteCarloRiskEstimator):
    """Monte Carlo estimator for the conditional value-at-risk (CVaR) objective.
    This estimator is proposed in the following references, and shown to be consistent:
    - Hong et al. (2014), "Monte Carlo Methods for Value-at-Risk and Conditional Value-at-Risk: A Review"
    - Trindade et al. (2007), "Financial prediction with constrained tail risk"
    When risk_level is larger than 1 - eps, it falls back to the max operator.

    Args:
        eps: Risk-level threshold to switch between CVaR and max. Defaults to 1e-4.
    """

    def __init__(self, eps: float = 1e-4) -> None:
        self.eps = eps

    def __call__(self, risk_level: Tensor, cost: Tensor, weights: Tensor) -> Tensor:
        """Computes and returns the conditional value-at-risk estimated on the cost tensor

        Args:
            risk_level: (batch_size, n_agents) tensor of risk-levels in [0, 1] at which the CVaR risk is computed
            cost: (batch_size, n_agents, n_samples) cost tensor
            weights: (batch_size, n_agents, n_samples) tensor of weights for the cost samples

        Returns:
            conditional value-at-risk tensor of size (batch_size, n_agents)
        """
        assert risk_level.shape[0] == cost.shape[0]
        assert risk_level.shape[1] == cost.shape[1]
        if weights is None:
            weights = torch.ones_like(cost) / cost.shape[-1]
        else:
            weights = weights / weights.sum(dim=-1, keepdim=True)
        if not (torch.all(0.0 <= risk_level) and torch.all(risk_level <= 1.0)):
            warnings.warn(
                "risk_level is defined only between 0.0 and 1.0 for CVaR. Exceeded values will be clamped."
            )
            risk_level = torch.clamp(risk_level, min=0.0, max=1.0)

        cvar_risk_high = cost.max(dim=-1).values

        # Sort the cost samples and find the value-at-risk: the smallest cost whose
        # cumulative probability weight reaches the risk level
        sorted_indices = torch.argsort(cost, dim=-1)
        cost_sorted = torch.gather(cost, -1, sorted_indices)
        weights_sorted = torch.gather(weights, -1, sorted_indices)
        idx_to_choose = torch.argmax(
            (weights_sorted.cumsum(dim=-1) >= risk_level.unsqueeze(-1)).float(), -1
        )

        value_at_risk_mc = cost_sorted.gather(-1, idx_to_choose.unsqueeze(-1)).squeeze(
            -1
        )

        # CVaR estimate: value-at-risk plus the weighted expected excess above it
        cvar_risk_mc = value_at_risk_mc + 1 / (1 - risk_level) * (
            (torch.relu(cost - value_at_risk_mc.unsqueeze(-1)) * weights).sum(dim=-1)
        )
        cvar = torch.where(risk_level > 1 - self.eps, cvar_risk_high, cvar_risk_mc)
        return cvar

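A hand-check of the estimator (hypothetical values): with uniform weights and risk_level 0.5, the CVaR is the mean of the worst half of the samples:

import torch

est = CVaREstimator()
cost = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
weights = torch.ones_like(cost) / 4
cvar = est(torch.full((1, 1), 0.5), cost, weights)  # VaR = 2, CVaR = 2 + 2 * 0.75 = 3.5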
def get_risk_estimator(estimator_params: dict) -> AbstractMonteCarloRiskEstimator:
    """Function that returns the Monte Carlo risk estimator that matches the given parameters.
    Tries to give comprehensive feedback and raises an error if the parameters are not recognized.

    Args:
        estimator_params: Risk estimator parameters, one of the following types (with
            different parameter values as desired):
            {"type": "entropic", "eps": 1e-4},
            {"type": "cvar", "eps": 1e-4}

    Raises:
        RuntimeError: If the given parameter dictionary does not match one of the expected formats.

    Returns:
        A risk estimator matching the given parameters.
    """
    known_types = ["entropic", "cvar"]
    try:
        if estimator_params["type"].lower() == "entropic":
            expected_params = inspect.getfullargspec(EntropicRiskEstimator)[0][1:]
            return EntropicRiskEstimator(estimator_params["eps"])
        elif estimator_params["type"].lower() == "cvar":
            expected_params = inspect.getfullargspec(CVaREstimator)[0][1:]
            return CVaREstimator(estimator_params["eps"])
        else:
            raise RuntimeError(
                f"Risk estimator '{estimator_params['type']}' is unknown. It should be one of {known_types}."
            )
    except KeyError:
        if "type" in estimator_params:
            raise RuntimeError(
                f"""The estimator '{estimator_params['type']}' is known but the given parameters
                {estimator_params} do not match the expected parameters {expected_params}."""
            )
        else:
            raise RuntimeError(
                f"""The given estimator parameters {estimator_params} do not define the estimator
                type in the field 'type'. Please add a field 'type' and set it to one of the
                handled types: {known_types}."""
            )


class AbstractRiskLevelSampler(ABC):
    """Abstract class for a risk-level sampler for training and evaluating risk-biased predictors"""

    @abstractmethod
    def sample(self, batch_size: int, device: torch.device) -> Tensor:
        """Returns a tensor of size batch_size with sampled risk-level values

        Args:
            batch_size: number of elements in the output tensor
            device: device of the output tensor

        Returns:
            A tensor of shape (batch_size,) filled with sampled risk values
        """

    @abstractmethod
    def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
        """Returns a tensor of size batch_size with high values of risk.

        Args:
            batch_size: number of elements in the output tensor
            device: device of the output tensor

        Returns:
            A tensor of shape (batch_size,) filled with the highest possible risk-level
        """


class UniformRiskLevelSampler(AbstractRiskLevelSampler):
    """Risk-level sampler with a uniform distribution

    Args:
        min: minimum risk-level
        max: maximum risk-level
    """

    def __init__(self, min: float, max: float) -> None:
        self.min = min
        self.max = max

    def sample(self, batch_size: int, device: torch.device) -> Tensor:
        return torch.rand(batch_size, device=device) * (self.max - self.min) + self.min

    def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
        return torch.ones(batch_size, device=device) * self.max


class NormalRiskLevelSampler(AbstractRiskLevelSampler):
    """Risk-level sampler with a normal distribution
|
228 |
+
|
229 |
+
Args:
|
230 |
+
mean: average risk-level
|
231 |
+
sigma: standard deviation of the sampler
|
232 |
+
"""
|
233 |
+
|
234 |
+
def __init__(self, mean: int, sigma: int) -> None:
|
235 |
+
self.mean = mean
|
236 |
+
self.sigma = sigma
|
237 |
+
|
238 |
+
def sample(self, batch_size: int, device: torch.device) -> Tensor:
|
239 |
+
return torch.randn(batch_size, device=device) * self.sigma + self.mean
|
240 |
+
|
241 |
+
def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
|
242 |
+
return torch.ones(batch_size, device=device) * self.sigma * 3
|
243 |
+
|
244 |
+
|
245 |
+
class BernoulliRiskLevelSampler(AbstractRiskLevelSampler):
|
246 |
+
"""Risk-level sampler with a scaled Bernoulli distribution
|
247 |
+
|
248 |
+
Args:
|
249 |
+
min: minimum risk-level
|
250 |
+
max: maximum risk-level
|
251 |
+
p: Bernoulli parameter
|
252 |
+
"""
|
253 |
+
|
254 |
+
def __init__(self, min: int, max: int, p: int) -> None:
|
255 |
+
self.min = min
|
256 |
+
self.max = max
|
257 |
+
self.p = p
|
258 |
+
|
259 |
+
def sample(self, batch_size: int, device: torch.device) -> Tensor:
|
260 |
+
return (
|
261 |
+
torch.bernoulli(torch.ones(batch_size, device=device) * self.p)
|
262 |
+
* (self.max - self.min)
|
263 |
+
+ self.min
|
264 |
+
)
|
265 |
+
|
266 |
+
def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
|
267 |
+
return torch.ones(batch_size, device=device) * self.max
|
268 |
+
|
269 |
+
|
270 |
+
class BetaRiskLevelSampler(AbstractRiskLevelSampler):
|
271 |
+
"""Risk-level sampler with a scaled Beta distribution
|
272 |
+
|
273 |
+
Distribution properties:
|
274 |
+
mean = alpha*(max-min)/(alpha + beta) + min
|
275 |
+
mode = (alpha-1)*(max-min)/(alpha + beta - 2) + min
|
276 |
+
variance = alpha*beta*(max-min)**2/((alpha+beta)**2*(alpha+beta+1))
|
277 |
+
|
278 |
+
Args:
|
279 |
+
min: minimum risk-level
|
280 |
+
max: maximum risk-level
|
281 |
+
alpha: First distribution parameter
|
282 |
+
beta: Second distribution parameter
|
283 |
+
"""
|
284 |
+
|
285 |
+
def __init__(self, min: int, max: int, alpha: float, beta: float) -> None:
|
286 |
+
self.min = min
|
287 |
+
self.max = max
|
288 |
+
self._distribution = torch.distributions.Beta(
|
289 |
+
torch.tensor([alpha], dtype=torch.float32),
|
290 |
+
torch.tensor([beta], dtype=torch.float32),
|
291 |
+
)
|
292 |
+
|
293 |
+
@property
|
294 |
+
def alpha(self):
|
295 |
+
return self._distribution.concentration1.item()
|
296 |
+
|
297 |
+
@alpha.setter
|
298 |
+
def alpha(self, alpha: float):
|
299 |
+
self._distribution = torch.distributions.Beta(
|
300 |
+
torch.tensor([alpha], dtype=torch.float32),
|
301 |
+
torch.tensor([self.beta], dtype=torch.float32),
|
302 |
+
)
|
303 |
+
|
304 |
+
@property
|
305 |
+
def beta(self):
|
306 |
+
return self._distribution.concentration0.item()
|
307 |
+
|
308 |
+
@beta.setter
|
309 |
+
def beta(self, beta: float):
|
310 |
+
self._distribution = torch.distributions.Beta(
|
311 |
+
torch.tensor([self.alpha], dtype=torch.float32),
|
312 |
+
torch.tensor([beta], dtype=torch.float32),
|
313 |
+
)
|
314 |
+
|
315 |
+
def sample(self, batch_size: int, device: torch.device) -> Tensor:
|
316 |
+
return (
|
317 |
+
self._distribution.sample((batch_size,)).to(device) * (self.max - self.min)
|
318 |
+
+ self.min
|
319 |
+
).view(batch_size)
|
320 |
+
|
321 |
+
def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
|
322 |
+
return torch.ones(batch_size, device=device) * self.max
|
323 |
+
|
324 |
+
|
325 |
+
class Chi2RiskLevelSampler(AbstractRiskLevelSampler):
|
326 |
+
"""Risk-level sampler with a scaled chi2 distribution
|
327 |
+
|
328 |
+
Distribution properties:
|
329 |
+
mean = k*scale + min
|
330 |
+
mode = max(k-2, 0)*scale + min
|
331 |
+
variance = 2*k*scale**2
|
332 |
+
|
333 |
+
Args:
|
334 |
+
min: minimum risk-level
|
335 |
+
scale: scaling factor for the risk-level
|
336 |
+
k: Chi2 parameter: degrees of freedom of the distribution
|
337 |
+
"""
|
338 |
+
|
339 |
+
def __init__(self, min: int, scale: float, k: int) -> None:
|
340 |
+
self.min = min
|
341 |
+
self.scale = scale
|
342 |
+
self._distribution = torch.distributions.Chi2(
|
343 |
+
torch.tensor([k], dtype=torch.float32)
|
344 |
+
)
|
345 |
+
|
346 |
+
@property
|
347 |
+
def k(self):
|
348 |
+
return self._distribution.df.item()
|
349 |
+
|
350 |
+
@k.setter
|
351 |
+
def k(self, k: int):
|
352 |
+
self._distribution = torch.distributions.Chi2(
|
353 |
+
torch.tensor([k], dtype=torch.float32)
|
354 |
+
)
|
355 |
+
|
356 |
+
def sample(self, batch_size: int, device: torch.device) -> Tensor:
|
357 |
+
return (
|
358 |
+
self._distribution.sample((batch_size,)).to(device) * self.scale + self.min
|
359 |
+
).view(batch_size)
|
360 |
+
|
361 |
+
def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
|
362 |
+
std = self.scale * math.sqrt(2 * self.k)
|
363 |
+
return torch.ones(batch_size, device=device) * std * 3
|
364 |
+
|
365 |
+
|
366 |
+
class LogNormalRiskLevelSampler(AbstractRiskLevelSampler):
|
367 |
+
"""Risk-level sampler with a scaled Beta distribution
|
368 |
+
|
369 |
+
Distribution properties:
|
370 |
+
mean = exp(mu + sigma**2/2)*scale + min
|
371 |
+
mode = exp(mu - sigma**2)*scale + min
|
372 |
+
variance = (exp(sigma**2)-1)*exp(2*mu+sigma**2)*scale**2
|
373 |
+
|
374 |
+
Args:
|
375 |
+
min: minimum risk-level
|
376 |
+
scale: scaling factor for the risk-level
|
377 |
+
mu: First distribution parameter
|
378 |
+
sigma: maximum risk-level
|
379 |
+
"""
|
380 |
+
|
381 |
+
def __init__(self, min: int, scale: float, mu: float, sigma: float) -> None:
|
382 |
+
self.min = min
|
383 |
+
self.scale = scale
|
384 |
+
self._distribution = torch.distributions.LogNormal(
|
385 |
+
torch.tensor([mu], dtype=torch.float32),
|
386 |
+
torch.tensor([sigma], dtype=torch.float32),
|
387 |
+
)
|
388 |
+
|
389 |
+
@property
|
390 |
+
def mu(self):
|
391 |
+
return self._distribution.loc.item()
|
392 |
+
|
393 |
+
@mu.setter
|
394 |
+
def mu(self, mu: float):
|
395 |
+
self._distribution = torch.distributions.LogNormal(
|
396 |
+
torch.tensor([mu], dtype=torch.float32),
|
397 |
+
torch.tensor([self.sigma], dtype=torch.float32),
|
398 |
+
)
|
399 |
+
|
400 |
+
@property
|
401 |
+
def sigma(self) -> float:
|
402 |
+
return self._distribution.scale.item()
|
403 |
+
|
404 |
+
@sigma.setter
|
405 |
+
def sigma(self, sigma: float):
|
406 |
+
self._distribution = torch.distributions.LogNormal(
|
407 |
+
torch.tensor([self.mu], dtype=torch.float32),
|
408 |
+
torch.tensor([sigma], dtype=torch.float32),
|
409 |
+
)
|
410 |
+
|
411 |
+
def sample(self, batch_size: int, device: torch.device) -> Tensor:
|
412 |
+
return (
|
413 |
+
self._distribution.sample((batch_size,)).to(device) * self.scale + self.min
|
414 |
+
).view(batch_size)
|
415 |
+
|
416 |
+
def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
|
417 |
+
std = (
|
418 |
+
(torch.exp(self.sigma.square()) - 1).sqrt()
|
419 |
+
* torch.exp(self.mu + self.sigma.square() / 2)
|
420 |
+
* self.scale
|
421 |
+
)
|
422 |
+
return torch.ones(batch_size, device=device) * 3 * std
|
423 |
+
|
424 |
+
|
425 |
+
class LogUniformRiskLevelSampler(AbstractRiskLevelSampler):
|
426 |
+
"""Risk-level sampler with a reversed log-uniform distribution (increasing density function). Between min and max.
|
427 |
+
|
428 |
+
Distribution properties:
|
429 |
+
mean = (max - min)/ln((max+1)/(min+1)) - 1/scale
|
430 |
+
mode = None
|
431 |
+
variance = (((max+1)^2 - (min+1)^2)/(2*ln((max+1)/(min+1))) - ((max - min)/ln((max+1)/(min+1)))^2)
|
432 |
+
|
433 |
+
Args:
|
434 |
+
min: minimum risk-level
|
435 |
+
max: maximum risk-level
|
436 |
+
scale: scale to apply to the sampling before applying exponential,
|
437 |
+
the output is rescaled back to fit in bounds [min, max] (the higher the scale the less uniform the distribution)
|
438 |
+
"""
|
439 |
+
|
440 |
+
def __init__(self, min: float, max: float, scale: float) -> None:
|
441 |
+
assert min >= 0
|
442 |
+
assert max > min
|
443 |
+
assert scale > 0
|
444 |
+
self.min = min
|
445 |
+
self.max = max
|
446 |
+
self.scale = scale
|
447 |
+
|
448 |
+
def sample(self, batch_size: int, device: torch.device) -> Tensor:
|
449 |
+
scale = self.scale / (self.max - self.min)
|
450 |
+
max = self.max * scale
|
451 |
+
min = self.min * scale
|
452 |
+
return (
|
453 |
+
max
|
454 |
+
- (
|
455 |
+
(
|
456 |
+
torch.rand(batch_size, device=device)
|
457 |
+
* (math.log(max + 1) - math.log(min + 1))
|
458 |
+
+ math.log(min + 1)
|
459 |
+
).exp()
|
460 |
+
- 1
|
461 |
+
)
|
462 |
+
+ min
|
463 |
+
) / scale
|
464 |
+
|
465 |
+
def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
|
466 |
+
return torch.ones(batch_size, device=device) * self.max
|
467 |
+
|
468 |
+
|
469 |
+
def get_risk_level_sampler(distribution_params: dict) -> AbstractRiskLevelSampler:
|
470 |
+
"""Function that returns the risk level sampler that matches the given parameters.
|
471 |
+
Tries to give a comprehensive feedback if the parameters are not recognized and raise an error.
|
472 |
+
|
473 |
+
Args:
|
474 |
+
Risk distribution should be one of the following types (with different parameter values as desired) :
|
475 |
+
{"type": "uniform", "min": 0, "max": 1},
|
476 |
+
{"type": "normal", "mean": 0, "sigma": 1},
|
477 |
+
{"type": "bernoulli", "p": 0.5, "min": 0, "max": 1},
|
478 |
+
{"type": "beta", "alpha": 2, "beta": 5, "min": 0, "max": 1},
|
479 |
+
{"type": "chi2", "k": 3, "min": 0, "scale": 1},
|
480 |
+
{"type": "log-normal", "mu": 0, "sigma": 1, "min": 0, "scale": 1}
|
481 |
+
{"type": "log-uniform", "min": 0, "max": 1, "scale": 1}
|
482 |
+
|
483 |
+
Raises:
|
484 |
+
RuntimeError: If the given parameter dictionary does not match one of the expected formats, raise a comprehensive error.
|
485 |
+
|
486 |
+
Returns:
|
487 |
+
A risk level sampler matching the given parameters.
|
488 |
+
"""
|
489 |
+
known_types = [
|
490 |
+
"uniform",
|
491 |
+
"normal",
|
492 |
+
"bernoulli",
|
493 |
+
"beta",
|
494 |
+
"chi2",
|
495 |
+
"log-normal",
|
496 |
+
"log-uniform",
|
497 |
+
]
|
498 |
+
try:
|
499 |
+
if distribution_params["type"].lower() == "uniform":
|
500 |
+
expected_params = inspect.getfullargspec(UniformRiskLevelSampler)[0][1:]
|
501 |
+
return UniformRiskLevelSampler(
|
502 |
+
distribution_params["min"], distribution_params["max"]
|
503 |
+
)
|
504 |
+
elif distribution_params["type"].lower() == "normal":
|
505 |
+
expected_params = inspect.getfullargspec(NormalRiskLevelSampler)[0][1:]
|
506 |
+
return NormalRiskLevelSampler(
|
507 |
+
distribution_params["mean"], distribution_params["sigma"]
|
508 |
+
)
|
509 |
+
elif distribution_params["type"].lower() == "bernoulli":
|
510 |
+
expected_params = inspect.getfullargspec(BernoulliRiskLevelSampler)[0][1:]
|
511 |
+
return BernoulliRiskLevelSampler(
|
512 |
+
distribution_params["min"],
|
513 |
+
distribution_params["max"],
|
514 |
+
distribution_params["p"],
|
515 |
+
)
|
516 |
+
elif distribution_params["type"].lower() == "beta":
|
517 |
+
expected_params = inspect.getfullargspec(BetaRiskLevelSampler)[0][1:]
|
518 |
+
return BetaRiskLevelSampler(
|
519 |
+
distribution_params["min"],
|
520 |
+
distribution_params["max"],
|
521 |
+
distribution_params["alpha"],
|
522 |
+
distribution_params["beta"],
|
523 |
+
)
|
524 |
+
elif distribution_params["type"].lower() == "chi2":
|
525 |
+
expected_params = inspect.getfullargspec(Chi2RiskLevelSampler)[0][1:]
|
526 |
+
return Chi2RiskLevelSampler(
|
527 |
+
distribution_params["min"],
|
528 |
+
distribution_params["scale"],
|
529 |
+
distribution_params["k"],
|
530 |
+
)
|
531 |
+
elif distribution_params["type"].lower() == "log-normal":
|
532 |
+
expected_params = inspect.getfullargspec(LogNormalRiskLevelSampler)[0][1:]
|
533 |
+
return LogNormalRiskLevelSampler(
|
534 |
+
distribution_params["min"],
|
535 |
+
distribution_params["scale"],
|
536 |
+
distribution_params["mu"],
|
537 |
+
distribution_params["sigma"],
|
538 |
+
)
|
539 |
+
elif distribution_params["type"].lower() == "log-uniform":
|
540 |
+
expected_params = inspect.getfullargspec(LogUniformRiskLevelSampler)[0][1:]
|
541 |
+
return LogUniformRiskLevelSampler(
|
542 |
+
distribution_params["min"],
|
543 |
+
distribution_params["max"],
|
544 |
+
distribution_params["scale"],
|
545 |
+
)
|
546 |
+
else:
|
547 |
+
raise RuntimeError(
|
548 |
+
f"Distribution {distribution_params['type']} is unknown. It should be one of {known_types}."
|
549 |
+
)
|
550 |
+
except KeyError:
|
551 |
+
if "type" in distribution_params:
|
552 |
+
raise RuntimeError(
|
553 |
+
f"The distribution '{distribution_params['type']}' is known but the given parameters {distribution_params} do not match the expected parameters {expected_params}."
|
554 |
+
)
|
555 |
+
else:
|
556 |
+
raise RuntimeError(
|
557 |
+
f"The given distribution parameters {distribution_params} do not define the distribution type in the field 'type'. Please add a field 'type' and set it to one of the handeled types: {known_types}."
|
558 |
+
)
|
559 |
+
|
560 |
+
|
561 |
+
if __name__ == "__main__":
|
562 |
+
import matplotlib.pyplot as plt
|
563 |
+
|
564 |
+
sampler = get_risk_level_sampler(
|
565 |
+
{"type": "log-uniform", "min": 0, "max": 1, "scale": 10}
|
566 |
+
)
|
567 |
+
# sampler = get_risk_level_sampler({"type": "normal", "mean": 0, "sigma": 1})
|
568 |
+
a = sampler.sample(10000, "cpu").detach().numpy()
|
569 |
+
_ = plt.hist(a, bins="auto") # arguments are passed to np.histogram
|
570 |
+
plt.title("Histogram with 'auto' bins")
|
571 |
+
plt.show()
|
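For quick orientation, here is a minimal usage sketch of the estimator interface defined in risk_biased/utils/risk.py above. It is not part of the committed file; the toy cost tensor and risk level are invented for illustration.

import torch

from risk_biased.utils.risk import get_risk_estimator

# Build a CVaR estimator from a parameter dictionary, as get_risk_estimator expects.
estimator = get_risk_estimator({"type": "cvar", "eps": 1e-4})

# One scene, one agent, eight Monte Carlo cost samples (toy values).
cost = torch.tensor([[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]]])
risk_level = torch.full((1, 1), 0.75)

# With uniform weights, CVaR at level 0.75 averages the top 25% of cost samples:
# VaR = 5.0 and CVaR = 5.0 + (1 / 0.25) * ((1.0 + 2.0) / 8) = 6.5.
print(estimator(risk_level, cost))  # tensor([[6.5000]])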
risk_biased/utils/torch_utils.py
ADDED
@@ -0,0 +1,66 @@
import warnings

import torch
from torch import Tensor


@torch.jit.script
def torch_linspace(start: Tensor, stop: Tensor, num: int) -> torch.Tensor:
    """
    Copy-pasted from https://github.com/pytorch/pytorch/issues/61292
    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
    """
    # create a tensor of 'num' steps from 0 to 1
    steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)

    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
    #   "cannot statically infer the expected size of a list in this context", hence the code below
    for i in range(start.ndim):
        steps = steps.unsqueeze(-1)

    # the output starts at 'start' and increments until 'stop' in each dimension
    out = start[None] + steps * (stop - start)[None]

    return out


def load_weights(
    model: torch.nn.Module, checkpoint: dict, strict: bool = True
) -> torch.nn.Module:
    """This function is used instead of the one provided by pytorch lightning
    because for unexplained reasons, the pytorch lightning load function did
    not behave as intended: loading several times from the same checkpoint
    resulted in different loaded weight values...

    Args:
        model: a model in which new weights should be set
        checkpoint: a loaded pytorch checkpoint (probably resulting from torch.load(filename))
        strict: Defaults to True, whether to fail if the checkpoint keys do not match the model keys
    """
    if not strict:
        model_dict = model.state_dict()
        diff1 = checkpoint["state_dict"].keys() - model_dict.keys()
        if diff1:
            warnings.warn(
                f"Found keys {diff1} in checkpoint without any match in the model, ignoring corresponding values."
            )
        diff2 = model_dict.keys() - checkpoint["state_dict"].keys()
        if diff2:
            warnings.warn(
                f"Missing keys {diff2} from the checkpoint, the corresponding weights will keep their initial values."
            )
        pretrained_dict = {
            k: v for k, v in checkpoint["state_dict"].items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
    else:
        model_dict = checkpoint["state_dict"]

    model.load_state_dict(model_dict, strict=strict)
    return model
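A small sanity-check sketch for torch_linspace follows (again not part of the commit; shapes and values are illustrative). It shows the batched-endpoint behaviour that torch.linspace, which only accepts scalar endpoints, does not provide.

import torch

from risk_biased.utils.torch_utils import torch_linspace

# Batched endpoints: 5 evenly spaced steps between two (2, 3) tensors.
start = torch.zeros(2, 3)
stop = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
steps = torch_linspace(start, stop, num=5)

assert steps.shape == (5, 2, 3)  # [num, *start.shape]
assert torch.allclose(steps[0], start) and torch.allclose(steps[-1], stop)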
risk_biased/utils/waymo_dataloader.py
ADDED
@@ -0,0 +1,490 @@
from typing import Tuple, List

from einops import rearrange, repeat
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
from torch import Tensor
import numpy as np
import pickle
import os

from mmcv import Config


class WaymoDataset(Dataset):
    """
    Dataset loader for custom preprocessed files of Waymo data.

    Args:
        cfg: configuration object holding the dataset paths and global settings
        split: dataset split to load ("train", "val", "test", or "sample")
        input_angle: whether to keep the heading angle in the state features
    """

    def __init__(self, cfg: Config, split: str, input_angle: bool = True):
        super(WaymoDataset, self).__init__()
        self.p_exchange_two_first = 1
        if "val" in split.lower():
            path = cfg.val_dataset_path
        elif "test" in split.lower():
            path = cfg.test_dataset_path
        elif "sample" in split.lower():
            path = cfg.sample_dataset_path
        else:
            path = cfg.train_dataset_path
            self.p_exchange_two_first = cfg.p_exchange_two_first

        self.file_list = [
            os.path.join(path, name)
            for name in os.listdir(path)
            if os.path.isfile(os.path.join(path, name))
        ]
        self.normalize = cfg.normalize_angle
        # self.load_dataset(path, 16)
        # self.idx_list = list(self.dataset.keys())
        self.input_angle = input_angle
        self.hist_len = cfg.num_steps
        self.fut_len = cfg.num_steps_future
        self.time_len = self.hist_len + self.fut_len
        self.min_num_obs = cfg.min_num_observation
        self.max_size_lane = cfg.max_size_lane
        self.random_rotation = cfg.random_rotation
        self.random_translation = cfg.random_translation
        self.angle_std = cfg.angle_std
        self.translation_distance_std = cfg.translation_distance_std
        self.max_num_agents = cfg.max_num_agents
        self.max_num_objects = cfg.max_num_objects
        self.state_dim = cfg.state_dim
        self.map_state_dim = cfg.map_state_dim
        self.dt = cfg.dt

        if "val" in os.path.basename(path).lower():
            self.dataset_size_limit = cfg.val_dataset_size_limit
        else:
            self.dataset_size_limit = cfg.train_dataset_size_limit

    def __len__(self):
        if self.dataset_size_limit is not None:
            return min(len(self.file_list), self.dataset_size_limit)
        else:
            return len(self.file_list)

    def __getitem__(
        self, idx: int
    ) -> Tuple[
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        float,
        np.ndarray,
        int,
    ]:
        """
        Get the item at index idx in the dataset. Normalize the scene and output absolute angle and position.

        Returns:
            trajectories, mask, mask_loss, lanes, mask_lanes, lane_states (None if absent), angle, mean_position, index
        """
        selected_file = self.file_list[idx]
        with open(selected_file, "rb") as handle:
            dataset = pickle.load(handle)
        rel_state_all = dataset["traj"]
        mask_all = dataset["mask_traj"]
        mask_loss = dataset["mask_to_predict"]
        rel_lane_all = dataset["lanes"]
        mask_lane_all = dataset["mask_lanes"]
        mean_pos = dataset["mean_pos"]
        assert (
            (
                rel_state_all[self.hist_len + 5 :, :, :2][mask_all[self.hist_len + 5 :]]
                != 0
            )
            .any(-1)
            .all()
        )
        assert (
            (
                rel_state_all[self.hist_len + 5 :, :, :2][
                    mask_loss[self.hist_len + 5 :]
                ]
                != 0
            )
            .any(-1)
            .all()
        )
        if "lane_states" in dataset.keys():
            lane_states = dataset["lane_states"]
        else:
            lane_states = None
        if np.random.rand() > self.p_exchange_two_first:
            rel_state_all[:, [0, 1]] = rel_state_all[:, [1, 0]]
            mask_all[:, [0, 1]] = mask_all[:, [1, 0]]
            mask_loss[:, [0, 1]] = mask_loss[:, [1, 0]]
        assert (
            (
                rel_state_all[self.hist_len + 5 :, :, :2][mask_all[self.hist_len + 5 :]]
                != 0
            )
            .any(-1)
            .all()
        )
        assert (
            (
                rel_state_all[self.hist_len + 5 :, :, :2][
                    mask_loss[self.hist_len + 5 :]
                ]
                != 0
            )
            .any(-1)
            .all()
        )
        if self.normalize:
            angle = rel_state_all[self.hist_len - 1, 1, 2]

            if self.random_rotation:
                if self.normalize:
                    angle += np.random.normal(0, self.angle_std)
                else:
                    angle += np.random.uniform(-np.pi, np.pi)
            if self.random_translation:
                distance = (
                    np.random.normal([0, 0], self.translation_distance_std, 2)
                    * mask_all[self.hist_len - 1 : self.hist_len, :, None]
                    - rel_state_all[self.hist_len - 1 : self.hist_len, 1:2, :2]
                )
            else:
                distance = -rel_state_all[self.hist_len - 1 : self.hist_len, 1:2, :2]

            rel_state_all[:, :, :2] += distance
            rel_lane_all[:, :, :2] += distance
            mean_pos += distance[0, 0, :]
            rel_state_all = self.scene_rotation(rel_state_all, -angle)
            rel_lane_all = self.scene_rotation(rel_lane_all, -angle)

        else:
            if self.random_translation:
                distance = np.random.normal([0, 0], self.translation_distance_std, 2)
                rel_state_all = (
                    rel_state_all
                    + mask_all[self.hist_len - 1 : self.hist_len, :, None] * distance
                )
                rel_lane_all = (
                    rel_lane_all
                    + mask_all[self.hist_len - 1 : self.hist_len, :, None] * distance
                )

            if self.random_rotation:
                angle = np.random.uniform(0, 2 * np.pi)
                rel_state_all = self.scene_rotation(rel_state_all, angle)
                rel_lane_all = self.scene_rotation(rel_lane_all, angle)
            else:
                angle = 0
        return (
            rel_state_all,
            mask_all,
            mask_loss,
            rel_lane_all,
            mask_lane_all,
            lane_states,
            angle,
            mean_pos,
            idx,
        )

    @staticmethod
    def scene_rotation(coor: np.ndarray, angle: float) -> np.ndarray:
        """
        Rotate all the coordinates with the same angle

        Args:
            coor: array of x, y coordinates
            angle: radians to rotate the coordinates by

        Returns:
            coor_rotated
        """
        rot_matrix = np.zeros((2, 2))
        c = np.cos(angle)
        s = np.sin(angle)
        rot_matrix[0, 0] = c
        rot_matrix[0, 1] = -s
        rot_matrix[1, 0] = s
        rot_matrix[1, 1] = c
        coor[..., :2] = np.matmul(
            rot_matrix, np.expand_dims(coor[..., :2], axis=-1)
        ).squeeze(-1)
        if coor.shape[-1] > 2:
            coor[..., 2] += angle
        if coor.shape[-1] >= 5:
            coor[..., 3:5] = np.matmul(
                rot_matrix, np.expand_dims(coor[..., 3:5], axis=-1)
            ).squeeze(-1)
        return coor

    def fill_past(self, past, mask_past):
        current_velocity = past[..., 0, 3:5]
        for t in range(1, past.shape[-2]):
            current_velocity = torch.where(
                mask_past[..., t, None], past[..., t, 3:5], current_velocity
            )
            past[..., t, 3:5] = current_velocity
            predicted_position = past[..., t - 1, :2] + current_velocity * self.dt
            past[..., t, :2] = torch.where(
                mask_past[..., t, None], past[..., t, :2], predicted_position
            )
        return past

    def collate_fn(
        self, samples: List
    ) -> Tuple[
        Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor
    ]:
        """
        Assemble trajectories into batches with 0-padding.

        Args:
            samples: list of sampled trajectories (list of outputs of __getitem__)

        Returns:
            (starred dimensions have different values from one batch to the next but the ones with the same name are consistent within the batch)
            batch : ((batch_size, num_agents*, num_steps, state_dim),  # past trajectories of all agents in the scene
                     (batch_size, num_agents*, num_steps),  # mask past, False where past trajectories are padding data
                     (batch_size, num_agents*, num_steps_future, state_dim),  # future trajectories
                     (batch_size, num_agents*, num_steps_future),  # mask future, False where future trajectories are padding data
                     (batch_size, num_agents*, num_steps_future),  # mask loss, False where future trajectories are not to be predicted
                     (batch_size, num_objects*, object_seq_len*, map_state_dim),  # map object sequences in the scene
                     (batch_size, num_objects*, object_seq_len*),  # mask map, False where map objects are padding data
                     (batch_size, num_agents*, state_dim),  # position offset of all agents relative to ego at present time
                     (batch_size, num_steps, state_dim),  # ego past trajectory
                     (batch_size, num_steps_future, state_dim))  # ego future trajectory
        """
        max_n_vehicle = 50
        max_n_lanes = 0
        for (
            coor,
            mask,
            mask_loss,
            lanes,
            mask_lanes,
            lane_states,
            mean_angle,
            mean_pos,
            idx,
        ) in samples:
            # time_len_coor = self._count_last_obs(coor, hist_len)
            # num_vehicle = np.sum(time_len_coor > self.min_num_obs)
            num_vehicle = coor.shape[1]
            num_lanes = lanes.shape[1]
            max_n_vehicle = max(num_vehicle, max_n_vehicle)
            max_n_lanes = max(num_lanes, max_n_lanes)
        if max_n_vehicle <= 0:
            raise RuntimeError
        data_batch = np.zeros(
            [self.time_len, len(samples), max_n_vehicle, self.state_dim]
        )
        mask_batch = np.zeros([self.time_len, len(samples), max_n_vehicle])
        mask_loss_batch = np.zeros([self.time_len, len(samples), max_n_vehicle])
        lane_batch = np.zeros(
            [self.max_size_lane, len(samples), max_n_lanes, self.map_state_dim]
        )
        mask_lane_batch = np.zeros([self.max_size_lane, len(samples), max_n_lanes])
        mean_angle_batch = np.zeros([len(samples)])
        mean_pos_batch = np.zeros([len(samples), 2])
        tag_list = np.zeros([len(samples)])
        idx_list = [0 for _ in range(len(samples))]

        for sample_ind, (
            coor,
            mask,
            mask_loss,
            lanes,
            mask_lanes,
            lane_states,
            mean_angle,
            mean_pos,
            idx,
        ) in enumerate(samples):
            data_batch[:, sample_ind, : coor.shape[1], :] = coor[: self.time_len, :, :]
            mask_batch[:, sample_ind, : mask.shape[1]] = mask[: self.time_len, :]
            mask_loss_batch[:, sample_ind, : mask.shape[1]] = mask_loss[
                : self.time_len, :
            ]
            lane_batch[: lanes.shape[0], sample_ind, : lanes.shape[1], :2] = lanes
            if lane_states is not None:
                lane_states = repeat(
                    lane_states[:, : self.hist_len],
                    "objects time features -> one objects (time features)",
                    one=1,
                )
                lane_batch[
                    : lanes.shape[0], sample_ind, : lanes.shape[1], 2:
                ] = lane_states
            mask_lane_batch[
                : mask_lanes.shape[0], sample_ind, : mask_lanes.shape[1]
            ] = mask_lanes
            mean_angle_batch[sample_ind] = mean_angle
            mean_pos_batch[sample_ind, :] = mean_pos
            # tag_list[sample_ind] = self.dataset[idx]["tag"]
            idx_list[sample_ind] = idx

        data_batch = torch.from_numpy(data_batch.astype("float32"))
        mask_batch = torch.from_numpy(mask_batch.astype("bool"))
        lane_batch = torch.from_numpy(lane_batch.astype("float32"))
        mask_lane_batch = torch.from_numpy(mask_lane_batch.astype("bool"))
        mean_pos_batch = torch.from_numpy(mean_pos_batch.astype("float32"))
        mask_loss_batch = torch.from_numpy(mask_loss_batch.astype("bool"))

        data_batch = rearrange(
            data_batch, "time batch agents features -> batch agents time features"
        )
        mask_batch = rearrange(mask_batch, "time batch agents -> batch agents time")
        mask_loss_batch = rearrange(
            mask_loss_batch, "time batch agents -> batch agents time"
        )
        lane_batch = rearrange(
            lane_batch,
            "object_seq_len batch objects features -> batch objects object_seq_len features",
        )
        mask_lane_batch = rearrange(
            mask_lane_batch,
            "object_seq_len batch objects -> batch objects object_seq_len",
        )

        # The two first agents are the ones interacting, others are sorted by distance from the first agent
        # Objects are also sorted by distance from the first agent
        # Therefore, the limits in number, max_num_agents and max_num_objects, can be seen as adaptive distance limits.

        if not self.input_angle:
            data_batch = torch.cat((data_batch[..., :2], data_batch[..., 3:]), dim=-1)
        traj_past = data_batch[:, : self.max_num_agents, : self.hist_len, :]
        mask_past = mask_batch[:, : self.max_num_agents, : self.hist_len]
        traj_fut = data_batch[
            :, : self.max_num_agents, self.hist_len : self.hist_len + self.fut_len, :
        ]
        mask_fut = mask_batch[
            :, : self.max_num_agents, self.hist_len : self.hist_len + self.fut_len
        ]
        ego_past = data_batch[:, 0:1, : self.hist_len, :]
        ego_fut = data_batch[:, 0:1, self.hist_len : self.hist_len + self.fut_len, :]

        lane_batch = lane_batch[:, : self.max_num_objects]
        mask_lane_batch = mask_lane_batch[:, : self.max_num_objects]

        # Define what to predict (could be from Waymo's label of what to predict or the other agent that interacts with the ego...)
        mask_loss_batch = torch.logical_and(
            mask_loss_batch[
                :, : self.max_num_agents, self.hist_len : self.hist_len + self.fut_len
            ],
            mask_past.any(-1, keepdim=True),
        )
        # Remove all other agents so the model only predicts the interacting agent (index 1)
        mask_loss_batch[:, 0] = False
        mask_loss_batch[:, 2:] = False

        # Normalize...
        # traj_past = self.fill_past(traj_past, mask_past)
        dynamic_state_size = 5 if self.input_angle else 4
        offset_batch = traj_past[..., -1, :dynamic_state_size].clone()
        traj_past[..., :dynamic_state_size] = traj_past[
            ..., :dynamic_state_size
        ] - offset_batch.unsqueeze(-2)
        traj_fut[..., :dynamic_state_size] = traj_fut[
            ..., :dynamic_state_size
        ] - offset_batch.unsqueeze(-2)

        return (
            traj_past,
            mask_past,
            traj_fut,
            mask_fut,
            mask_loss_batch,
            lane_batch,
            mask_lane_batch,
            offset_batch,
            ego_past,
            ego_fut,
        )


class WaymoDataloaders:
    def __init__(self, cfg: Config) -> None:
        self.cfg = cfg

    def sample_dataloader(self) -> DataLoader:
        """Setup and return sample DataLoader

        Returns:
            DataLoader: sample DataLoader
        """
        dataset = WaymoDataset(self.cfg, "sample")
        sample_loader = DataLoader(
            dataset=dataset,
            batch_size=self.cfg.batch_size,
            shuffle=False,
            num_workers=self.cfg.num_workers,
            collate_fn=dataset.collate_fn,
            drop_last=True,
        )
        return sample_loader

    def val_dataloader(
        self, drop_last=True, shuffle=False, input_angle=True
    ) -> DataLoader:
        """Setup and return validation DataLoader

        Returns:
            DataLoader: validation DataLoader
        """
        dataset = WaymoDataset(self.cfg, "val", input_angle)
        val_loader = DataLoader(
            dataset=dataset,
            batch_size=self.cfg.batch_size,
            shuffle=shuffle,
            num_workers=self.cfg.num_workers,
            collate_fn=dataset.collate_fn,
            drop_last=drop_last,
        )
        torch.cuda.empty_cache()
        return val_loader

    def train_dataloader(
        self, drop_last=True, shuffle=True, input_angle=True
    ) -> DataLoader:
        """Setup and return training DataLoader

        Returns:
            DataLoader: training DataLoader
        """
        dataset = WaymoDataset(self.cfg, "train", input_angle)
        train_loader = DataLoader(
            dataset=dataset,
            batch_size=self.cfg.batch_size,
            shuffle=shuffle,
            num_workers=self.cfg.num_workers,
            collate_fn=dataset.collate_fn,
            drop_last=drop_last,
        )
        torch.cuda.empty_cache()
        return train_loader

    def test_dataloader(self) -> DataLoader:
        """Setup and return test DataLoader

        Returns:
            DataLoader: test DataLoader
        """
        raise NotImplementedError("The waymo dataloader cannot load test samples yet.")

    @staticmethod
    def unnormalize_trajectory(
        input: torch.Tensor, offset: torch.Tensor
    ) -> torch.Tensor:
        """Unnormalize trajectory by adding offset to input

        Args:
            input : (..., (n_sample), num_steps_future, state_dim) tensor of future
                trajectory y
            offset : (..., state_dim) tensor of offset to add to y

        Returns:
            Unnormalized trajectory that has the same size as input
        """
        assert offset.ndim == 3
        batch_size, num_agents = offset.shape[:2]
        offset_state_dim = offset.shape[-1]
        assert offset_state_dim <= input.shape[-1]
        assert input.shape[0] == batch_size
        assert input.shape[1] == num_agents
        input_copy = input.clone()

        input_copy[..., :offset_state_dim] = input_copy[
            ..., :offset_state_dim
        ] + offset[..., : input.shape[-1]].reshape(
            [batch_size, num_agents, *([1] * (input.ndim - 3)), offset_state_dim]
        )
        return input_copy
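A sketch of how these dataloaders are meant to be driven, mirroring the way scripts/eval_scripts load configurations (not part of the commit; it assumes the dataset paths configured under risk_biased/config point at existing preprocessed Waymo pickle files):

from risk_biased.utils.config_argparse import config_argparse
from risk_biased.utils.waymo_dataloader import WaymoDataloaders

# Merge the learning and Waymo configs the same way compute_stats.py does.
cfg = config_argparse(
    ["risk_biased/config/learning_config.py", "risk_biased/config/waymo_config.py"]
)

dataloaders = WaymoDataloaders(cfg)
train_loader = dataloaders.train_dataloader(shuffle=True)

# collate_fn packs each batch into the 10-tuple documented above.
(
    traj_past,
    mask_past,
    traj_fut,
    mask_fut,
    mask_loss,
    lanes,
    mask_lanes,
    offset,
    ego_past,
    ego_fut,
) = next(iter(train_loader))
print(traj_past.shape)  # (batch_size, num_agents, num_steps, state_dim)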
scripts/eval_scripts/compute_stats.py
ADDED
@@ -0,0 +1,523 @@
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import pickle
|
4 |
+
import sys
|
5 |
+
|
6 |
+
from functools import partial
|
7 |
+
from inspect import signature
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from einops import repeat
|
12 |
+
import fire
|
13 |
+
import numpy as np
|
14 |
+
from pytorch_lightning.utilities.seed import seed_everything
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from risk_biased.utils.config_argparse import config_argparse
|
18 |
+
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams, get_cost
|
19 |
+
from risk_biased.utils.risk import get_risk_estimator
|
20 |
+
|
21 |
+
from risk_biased.utils.load_model import load_from_config
|
22 |
+
|
23 |
+
|
24 |
+
def to_device(batch, device):
|
25 |
+
output = []
|
26 |
+
for item in batch:
|
27 |
+
output.append(item.to(device))
|
28 |
+
return output
|
29 |
+
|
30 |
+
|
31 |
+
class CPU_Unpickler(pickle.Unpickler):
|
32 |
+
def find_class(self, module, name):
|
33 |
+
if module == "torch.storage" and name == "_load_from_bytes":
|
34 |
+
return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
|
35 |
+
else:
|
36 |
+
return super().find_class(module, name)
|
37 |
+
|
38 |
+
|
39 |
+
def distance(pred, truth):
|
40 |
+
"""
|
41 |
+
pred (Tensor): (..., time, xy)
|
42 |
+
truth (Tensor): (..., time, xy)
|
43 |
+
mask_loss (Tensor): (..., time) Defaults to None.
|
44 |
+
"""
|
45 |
+
return torch.sqrt(torch.sum(torch.square(pred[..., :2] - truth[..., :2]), -1))
|
46 |
+
|
47 |
+
|
48 |
+
def compute_metrics(
|
49 |
+
predictor,
|
50 |
+
batch,
|
51 |
+
cost,
|
52 |
+
risk_levels,
|
53 |
+
risk_estimator,
|
54 |
+
dt,
|
55 |
+
unnormalizer,
|
56 |
+
n_samples_risk,
|
57 |
+
n_samples_stats,
|
58 |
+
):
|
59 |
+
|
60 |
+
# risk_unbiased
|
61 |
+
# risk_biased
|
62 |
+
# cost
|
63 |
+
# FDE: unbiased, biased(risk_level=[0, 0.3, 0.5, 0.8, 1]) (for all samples so minFDE can be computed later)
|
64 |
+
# ADE (for all samples so minADE can be computed later)
|
65 |
+
|
66 |
+
x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
|
67 |
+
mask_z = mask_x.any(-1)
|
68 |
+
|
69 |
+
_, z_mean_inference, z_log_std_inference = predictor.model(
|
70 |
+
x,
|
71 |
+
mask_x,
|
72 |
+
map,
|
73 |
+
mask_map,
|
74 |
+
offset=offset,
|
75 |
+
x_ego=x_ego,
|
76 |
+
y_ego=y_ego,
|
77 |
+
risk_level=None,
|
78 |
+
)
|
79 |
+
|
80 |
+
latent_distribs = {
|
81 |
+
"inference": {
|
82 |
+
"mean": z_mean_inference[:, 1].detach().cpu(),
|
83 |
+
"log_std": z_log_std_inference[:, 1].detach().cpu(),
|
84 |
+
}
|
85 |
+
}
|
86 |
+
inference_distances = []
|
87 |
+
cost_list = []
|
88 |
+
# Cut the number of samples in packs to avoid out-of-memory problems
|
89 |
+
# Compute and store cost for all packs
|
90 |
+
for _ in range(n_samples_risk // n_samples_stats):
|
91 |
+
z_samples_inference = predictor.model.inference_encoder.sample(
|
92 |
+
z_mean_inference,
|
93 |
+
z_log_std_inference,
|
94 |
+
n_samples=n_samples_stats,
|
95 |
+
)
|
96 |
+
|
97 |
+
y_samples = predictor.model.decode(
|
98 |
+
z_samples=z_samples_inference,
|
99 |
+
mask_z=mask_z,
|
100 |
+
x=x,
|
101 |
+
mask_x=mask_x,
|
102 |
+
map=map,
|
103 |
+
mask_map=mask_map,
|
104 |
+
offset=offset,
|
105 |
+
)
|
106 |
+
|
107 |
+
mask_loss_samples = repeat(mask_loss, "b a t -> b a s t", s=n_samples_stats)
|
108 |
+
# Computing unbiased cost
|
109 |
+
cost_list.append(
|
110 |
+
get_cost(
|
111 |
+
cost,
|
112 |
+
x,
|
113 |
+
y_samples,
|
114 |
+
offset,
|
115 |
+
x_ego,
|
116 |
+
y_ego,
|
117 |
+
dt,
|
118 |
+
unnormalizer,
|
119 |
+
mask_loss_samples,
|
120 |
+
)[:, 1:2]
|
121 |
+
)
|
122 |
+
inference_distances.append(distance(y_samples, y.unsqueeze(2))[:, 1:2])
|
123 |
+
cost_dic = {}
|
124 |
+
cost_dic["inference"] = torch.cat(cost_list, 2).detach().cpu()
|
125 |
+
distance_dic = {}
|
126 |
+
distance_dic["inference"] = torch.cat(inference_distances, 2).detach().cpu()
|
127 |
+
|
128 |
+
# Set up the output risk tensor
|
129 |
+
risk_dic = {}
|
130 |
+
|
131 |
+
# Loop on risk_level values to fill the risk estimation for each value and compute stats at each risk level
|
132 |
+
for rl in risk_levels:
|
133 |
+
risk_level = (
|
134 |
+
torch.ones(
|
135 |
+
(x.shape[0], x.shape[1]),
|
136 |
+
device=x.device,
|
137 |
+
)
|
138 |
+
* rl
|
139 |
+
)
|
140 |
+
risk_dic[f"biased_{rl}"] = risk_estimator(
|
141 |
+
risk_level[:, 1:2].detach().cpu(), cost_dic["inference"]
|
142 |
+
)
|
143 |
+
|
144 |
+
y_samples_biased, z_mean_biased, z_log_std_biased = predictor.model(
|
145 |
+
x,
|
146 |
+
mask_x,
|
147 |
+
map,
|
148 |
+
mask_map,
|
149 |
+
offset=offset,
|
150 |
+
x_ego=x_ego,
|
151 |
+
y_ego=y_ego,
|
152 |
+
risk_level=risk_level,
|
153 |
+
n_samples=n_samples_stats,
|
154 |
+
)
|
155 |
+
latent_distribs[f"biased_{rl}"] = {
|
156 |
+
"mean": z_mean_biased[:, 1].detach().cpu(),
|
157 |
+
"log_std": z_log_std_biased[:, 1].detach().cpu(),
|
158 |
+
}
|
159 |
+
|
160 |
+
distance_dic[f"biased_{rl}"] = (
|
161 |
+
distance(y_samples_biased, y.unsqueeze(2))[:, 1].detach().cpu()
|
162 |
+
)
|
163 |
+
cost_dic[f"biased_{rl}"] = (
|
164 |
+
get_cost(
|
165 |
+
cost,
|
166 |
+
x,
|
167 |
+
y_samples_biased,
|
168 |
+
offset,
|
169 |
+
x_ego,
|
170 |
+
y_ego,
|
171 |
+
dt,
|
172 |
+
unnormalizer,
|
173 |
+
mask_loss_samples,
|
174 |
+
)[:, 1]
|
175 |
+
.detach()
|
176 |
+
.cpu()
|
177 |
+
)
|
178 |
+
|
179 |
+
# Return risks for the batch and all risk values
|
180 |
+
return {
|
181 |
+
"risk": risk_dic,
|
182 |
+
"cost": cost_dic,
|
183 |
+
"distance": distance_dic,
|
184 |
+
"latent_distribs": latent_distribs,
|
185 |
+
"mask": mask_loss[:, 1].detach().cpu(),
|
186 |
+
}
|
187 |
+
|
188 |
+
|
189 |
+
def cat_metrics_rec(metrics1, metrics2, cat_to):
|
190 |
+
for key in metrics1.keys():
|
191 |
+
if key not in metrics2.keys():
|
192 |
+
raise RuntimeError(
|
193 |
+
f"Trying to concatenate objects with different keys: {key} is not in second argument keys."
|
194 |
+
)
|
195 |
+
elif isinstance(metrics1[key], dict):
|
196 |
+
if key not in cat_to.keys():
|
197 |
+
cat_to[key] = {}
|
198 |
+
cat_metrics_rec(metrics1[key], metrics2[key], cat_to[key])
|
199 |
+
elif isinstance(metrics1[key], torch.Tensor):
|
200 |
+
cat_to[key] = torch.cat((metrics1[key], metrics2[key]), 0)
|
201 |
+
|
202 |
+
|
203 |
+
def cat_metrics(metrics1, metrics2):
|
204 |
+
out = {}
|
205 |
+
cat_metrics_rec(metrics1, metrics2, out)
|
206 |
+
return out
|
207 |
+
|
208 |
+
|
209 |
+
def masked_mean_std_ste(data, mask):
|
210 |
+
mask = mask.view(data.shape)
|
211 |
+
norm = mask.sum().clamp_min(1)
|
212 |
+
mean = (data * mask).sum() / norm
|
213 |
+
std = torch.sqrt(((data - mean) * mask).square().sum() / norm)
|
214 |
+
return mean.item(), std.item(), (std / torch.sqrt(norm)).item()
|
215 |
+
|
216 |
+
|
217 |
+
def masked_mean_range(data, mask):
|
218 |
+
data = data[mask]
|
219 |
+
mean = data.mean()
|
220 |
+
min = torch.quantile(data, 0.05)
|
221 |
+
max = torch.quantile(data, 0.95)
|
222 |
+
return mean, min, max
|
223 |
+
|
224 |
+
|
225 |
+
def masked_mean_dim(data, mask, dim):
|
226 |
+
norm = mask.sum(dim).clamp_min(1)
|
227 |
+
mean = (data * mask).sum(dim) / norm
|
228 |
+
return mean
|
229 |
+
|
230 |
+
|
231 |
+
def plot_risk_error(metrics, risk_levels, risk_estimator, max_n_samples, path_save):
|
232 |
+
cost_inference = metrics["cost"]["inference"]
|
233 |
+
cost_biased_0 = metrics["cost"]["biased_0"]
|
234 |
+
mask = metrics["mask"].any(1)
|
235 |
+
ones_tensor = torch.ones(mask.shape[0])
|
236 |
+
n_samples = np.minimum(cost_biased_0.shape[1], max_n_samples)
|
237 |
+
|
238 |
+
for rl in risk_levels:
|
239 |
+
key = f"biased_{rl}"
|
240 |
+
reference_risk = metrics["risk"][key]
|
241 |
+
mean_inference_risk_error_per_samples = np.zeros(n_samples - 1)
|
242 |
+
min_inference_risk_error_per_samples = np.zeros(n_samples - 1)
|
243 |
+
max_inference_risk_error_per_samples = np.zeros(n_samples - 1)
|
244 |
+
# mean_biased_0_risk_error_per_samples = np.zeros(n_samples-1)
|
245 |
+
# min_biased_0_risk_error_per_samples = np.zeros(n_samples-1)
|
246 |
+
# max_biased_0_risk_error_per_samples = np.zeros(n_samples-1)
|
247 |
+
mean_biased_risk_error_per_samples = np.zeros(n_samples - 1)
|
248 |
+
min_biased_risk_error_per_samples = np.zeros(n_samples - 1)
|
249 |
+
max_biased_risk_error_per_samples = np.zeros(n_samples - 1)
|
250 |
+
risk_level_tensor = ones_tensor * rl
|
251 |
+
for sub_samples in range(1, n_samples):
|
252 |
+
perm = torch.randperm(metrics["cost"][key].shape[1])[:sub_samples]
|
253 |
+
risk_error_biased = metrics["cost"][key][:, perm].mean(1) - reference_risk
|
254 |
+
(
|
255 |
+
mean_biased_risk_error_per_samples[sub_samples - 1],
|
256 |
+
min_biased_risk_error_per_samples[sub_samples - 1],
|
257 |
+
max_biased_risk_error_per_samples[sub_samples - 1],
|
258 |
+
) = masked_mean_range(risk_error_biased, mask)
|
259 |
+
risk_error_inference = (
|
260 |
+
risk_estimator(risk_level_tensor, cost_inference[:, :, :sub_samples])
|
261 |
+
- reference_risk
|
262 |
+
)
|
263 |
+
(
|
264 |
+
mean_inference_risk_error_per_samples[sub_samples - 1],
|
265 |
+
min_inference_risk_error_per_samples[sub_samples - 1],
|
266 |
+
max_inference_risk_error_per_samples[sub_samples - 1],
|
267 |
+
) = masked_mean_range(risk_error_inference, mask)
|
268 |
+
# risk_error_biased_0 = risk_estimator(risk_level_tensor, cost_biased_0[:, :sub_samples]) - reference_risk
|
269 |
+
# (mean_biased_0_risk_error_per_samples[sub_samples-1], min_biased_0_risk_error_per_samples[sub_samples-1], max_biased_0_risk_error_per_samples[sub_samples-1]) = masked_mean_range(risk_error_biased_0, mask)
|
270 |
+
|
271 |
+
plt.plot(
|
272 |
+
range(1, n_samples),
|
273 |
+
mean_inference_risk_error_per_samples,
|
274 |
+
label="Inference",
|
275 |
+
)
|
276 |
+
plt.fill_between(
|
277 |
+
range(1, n_samples),
|
278 |
+
min_inference_risk_error_per_samples,
|
279 |
+
max_inference_risk_error_per_samples,
|
280 |
+
alpha=0.3,
|
281 |
+
)
|
282 |
+
|
283 |
+
# plt.plot(range(1, n_samples), mean_biased_0_risk_error_per_samples, label="Unbiased")
|
284 |
+
# plt.fill_between(range(1, n_samples), min_biased_0_risk_error_per_samples, max_biased_0_risk_error_per_samples, alpha=.3)
|
285 |
+
|
286 |
+
plt.plot(
|
287 |
+
range(1, n_samples), mean_biased_risk_error_per_samples, label="Biased"
|
288 |
+
)
|
289 |
+
plt.fill_between(
|
290 |
+
range(1, n_samples),
|
291 |
+
min_biased_risk_error_per_samples,
|
292 |
+
max_biased_risk_error_per_samples,
|
293 |
+
alpha=0.3,
|
294 |
+
)
|
295 |
+
plt.ylim(
|
296 |
+
np.min(min_inference_risk_error_per_samples),
|
297 |
+
np.max(max_biased_risk_error_per_samples),
|
298 |
+
)
|
299 |
+
|
300 |
+
plt.hlines(y=0, xmin=0, xmax=n_samples, colors="black", linestyles="--", lw=0.3)
|
301 |
+
|
302 |
+
plt.xlabel("Number of samples")
|
303 |
+
plt.ylabel("Risk estimation error")
|
304 |
+
plt.legend()
|
305 |
+
plt.title(f"Risk estimation error at risk-level={rl}")
|
306 |
+
plt.gcf().set_size_inches(4, 3)
|
307 |
+
plt.legend(loc="lower right")
|
308 |
+
plt.savefig(fname=os.path.join(path_save, f"risk_level_{rl}.svg"))
|
309 |
+
plt.savefig(fname=os.path.join(path_save, f"risk_level_{rl}.png"))
|
310 |
+
plt.clf()
|
311 |
+
# plt.show()
|
312 |
+
|
313 |
+
|
def compute_stats(metrics, n_samples_mean_cost=4):
    # Risk estimation error: mean cost of the first n_samples_mean_cost biased
    # samples minus the target risk value.
    biased_risk_estimate = {}
    for key in metrics["cost"].keys():
        if key == "inference":
            continue
        risk = metrics["risk"][key]
        mean_cost = metrics["cost"][key][:, :n_samples_mean_cost].mean(1)
        risk_error = mean_cost - risk
        biased_risk_estimate[key] = {}
        (
            biased_risk_estimate[key]["mean"],
            biased_risk_estimate[key]["std"],
            biased_risk_estimate[key]["ste"],
        ) = masked_mean_std_ste(risk_error, metrics["mask"].any(1))

        (
            biased_risk_estimate[key]["mean_abs"],
            biased_risk_estimate[key]["std_abs"],
            biased_risk_estimate[key]["ste_abs"],
        ) = masked_mean_std_ste(risk_error.abs(), metrics["mask"].any(1))

    risk_stats = {}
    for key in metrics["risk"].keys():
        risk_stats[key] = {}
        (
            risk_stats[key]["mean"],
            risk_stats[key]["std"],
            risk_stats[key]["ste"],
        ) = masked_mean_std_ste(metrics["risk"][key], metrics["mask"].any(1))

    cost_stats = {}
    for key in metrics["cost"].keys():
        cost_stats[key] = {}
        (
            cost_stats[key]["mean"],
            cost_stats[key]["std"],
            cost_stats[key]["ste"],
        ) = masked_mean_std_ste(
            metrics["cost"][key], metrics["mask"].any(-1, keepdim=True)
        )

    distance_stats = {}
    for key in metrics["distance"].keys():
        distance_stats[key] = {"FDE": {}, "ADE": {}, "minFDE": {}, "minADE": {}}
        # FDE: final displacement error (distance at the last time step).
        (
            distance_stats[key]["FDE"]["mean"],
            distance_stats[key]["FDE"]["std"],
            distance_stats[key]["FDE"]["ste"],
        ) = masked_mean_std_ste(
            metrics["distance"][key][..., -1], metrics["mask"][:, None, -1]
        )
        # ADE: average displacement error (distance averaged over time steps).
        mean_dist = masked_mean_dim(
            metrics["distance"][key], metrics["mask"][:, None, :], -1
        )
        (
            distance_stats[key]["ADE"]["mean"],
            distance_stats[key]["ADE"]["std"],
            distance_stats[key]["ADE"]["ste"],
        ) = masked_mean_std_ste(mean_dist, metrics["mask"].any(-1, keepdim=True))
        # minFDE/minADE over the best of the first i samples, for several i.
        for i in [6, 16, 32]:
            distance_stats[key]["minFDE"][i] = {}
            min_dist, _ = metrics["distance"][key][:, :i, -1].min(1)
            (
                distance_stats[key]["minFDE"][i]["mean"],
                distance_stats[key]["minFDE"][i]["std"],
                distance_stats[key]["minFDE"][i]["ste"],
            ) = masked_mean_std_ste(min_dist, metrics["mask"][:, -1])
            distance_stats[key]["minADE"][i] = {}
            mean_dist, _ = masked_mean_dim(
                metrics["distance"][key][:, :i], metrics["mask"][:, None, :], -1
            ).min(1)
            (
                distance_stats[key]["minADE"][i]["mean"],
                distance_stats[key]["minADE"][i]["std"],
                distance_stats[key]["minADE"][i]["ste"],
            ) = masked_mean_std_ste(mean_dist, metrics["mask"].any(-1))
    return {
        "risk": risk_stats,
        "biased_risk_estimate": biased_risk_estimate,
        "cost": cost_stats,
        "distance": distance_stats,
    }
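
# NOTE (added for clarity): print_stats below emits a LaTeX tabular. The
# slash/brace_open/brace_close indirection works around f-string limits
# (before Python 3.12, f-string expressions cannot contain backslashes, and
# literal braces in them would need awkward escaping). A typical biased row
# looks like (values illustrative only):
#
#     Biased CVAE & 0.5 & $1.23$ \scriptsize{$\pm 0.01$} & ... \\
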
def print_stats(stats, n_samples_mean_cost=4):
    slash = "\\"
    brace_open = "{"
    brace_close = "}"
    print("\\begin{tabular}{lccccc}")
    print("\\hline")
    print(
        f"Predictive Model & ${slash}sigma$ & minFDE(16) & FDE (1) & Risk est. error ({n_samples_mean_cost}) & Risk est. $|$error$|$ ({n_samples_mean_cost}) {slash}{slash}"
    )
    print("\\hline")

    for key in stats["distance"].keys():
        strg = (
            f" ${stats['distance'][key]['minFDE'][16]['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['distance'][key]['minFDE'][16]['ste']:.2f}${brace_close}"
            + f"& ${stats['distance'][key]['FDE']['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['distance'][key]['FDE']['ste']:.2f}${brace_close}"
        )

        if key == "inference":
            strg = (
                "Unbiased CVAE & "
                + f"{slash}scriptsize{brace_open}NA{brace_close} &"
                + strg
                + f"& {slash}scriptsize{brace_open}NA{brace_close} & {slash}scriptsize{brace_open}NA{brace_close} {slash}{slash}"
            )
            print(strg)
            print("\\hline")
        else:
            strg = (
                "Biased CVAE & "
                + f"{key[7:]} & "
                + strg
                + f"& ${stats['biased_risk_estimate'][key]['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['biased_risk_estimate'][key]['ste']:.2f}${brace_close}"
                + f"& ${stats['biased_risk_estimate'][key]['mean_abs']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['biased_risk_estimate'][key]['ste_abs']:.2f}${brace_close}"
                + f"{slash}{slash}"
            )
            print(strg)
    print("\\hline")
    print("\\end{tabular}")

def main(
    log_path,
    force_recompute,
    n_samples_risk=256,
    n_samples_stats=32,
    n_samples_plot=16,
    args_to_parser=[],
):
    # Overwrite sys.argv so it doesn't mess up the parser.
    sys.argv = sys.argv[0:1] + args_to_parser
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "risk_biased", "config", "learning_config.py"
    )
    waymo_config_path = os.path.join(
        working_dir, "..", "..", "risk_biased", "config", "waymo_config.py"
    )
    cfg = config_argparse([config_path, waymo_config_path])

    file_path = os.path.join(log_path, f"metrics_{cfg.load_from}.pickle")
    fig_path = os.path.join(log_path, f"plots_{cfg.load_from}")
    if not os.path.exists(fig_path):
        os.makedirs(fig_path)

    risk_levels = [0, 0.3, 0.5, 0.8, 0.95, 1]
    cost = TTCCostTorch(TTCCostParams.from_config(cfg))
    risk_estimator = get_risk_estimator(cfg.risk_estimator)
    n_samples_mean_cost = 4

    if not os.path.exists(file_path) or force_recompute:
        with torch.no_grad():
            if cfg.seed is not None:
                seed_everything(cfg.seed)

            predictor, dataloaders, cfg = load_from_config(cfg)
            device = torch.device(cfg.gpus[0])
            predictor = predictor.to(device)

            val_loader = dataloaders.val_dataloader(shuffle=False, drop_last=False)

            # This loops over batches in the validation dataset.
            beg = 0
            metrics_all = None
            for val_batch in tqdm(val_loader):
                end = beg + val_batch[0].shape[0]
                metrics = compute_metrics(
                    predictor=predictor,
                    batch=to_device(val_batch, device),
                    cost=cost,
                    risk_levels=risk_levels,
                    risk_estimator=risk_estimator,
                    dt=cfg.dt,
                    unnormalizer=dataloaders.unnormalize_trajectory,
                    n_samples_risk=n_samples_risk,
                    n_samples_stats=n_samples_stats,
                )
                if metrics_all is None:
                    metrics_all = metrics
                else:
                    metrics_all = cat_metrics(metrics_all, metrics)
                beg = end
            with open(file_path, "wb") as handle:
                pickle.dump(metrics_all, handle)
    else:
        print(f"Loading pre-computed metrics from {file_path}")
        with open(file_path, "rb") as handle:
            metrics_all = CPU_Unpickler(handle).load()

    stats = compute_stats(metrics_all, n_samples_mean_cost=n_samples_mean_cost)
    print_stats(stats, n_samples_mean_cost=n_samples_mean_cost)
    plot_risk_error(metrics_all, risk_levels, risk_estimator, n_samples_plot, fig_path)

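# NOTE (worked example, added for clarity; the log dir and run id are the ones
# from the usage comment below): with an invocation like
#     python compute_stats.py ./logs/002/ False 256 32 16 --load_from 1uail32
# `main` has 6 parameters, so the argv split in the __main__ block leaves
# args_to_parser = ["--load_from", "1uail32"] for the risk_biased argparse,
# while Fire consumes the 5 positional values.
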
if __name__ == "__main__":
    # main("./logs/002/", False, 256, 32, 16)
    # Fire turns the main function into a script, then the risk_biased module argparse reads the other arguments.
    # Thus, the way to use it is:
    # >python compute_stats.py <path to existing log dir> <Force recompute> <n_samples_risk> <n_samples_stats> <n_samples_plot> <other argparse arguments, example --load_from 1uail32>

    # This is a hack to separate the Fire script args from the argparse arguments.
    args_to_parser = sys.argv[len(signature(main).parameters) :]
    partial_main = partial(main, args_to_parser=args_to_parser)
    sys.argv = sys.argv[: len(signature(main).parameters)]

    # Runs main as a script.
    fire.Fire(partial_main)
scripts/eval_scripts/draw_cost.py
ADDED
@@ -0,0 +1,84 @@
import os

import matplotlib.pyplot as plt
from mmcv import Config
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything

from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams
from risk_biased.scene_dataset.scene_plotter import ScenePlotter
from risk_biased.utils.cost import (
    DistanceCostNumpy,
    DistanceCostParams,
    TTCCostNumpy,
    TTCCostParams,
)
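
# NOTE (summary added for clarity, not in the original commit): this script
# renders the time-to-collision (TTC) cost field over a grid of pedestrian
# positions, one panel per pedestrian heading, with the ego vehicle driving
# along the road in the +x direction.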
if __name__ == "__main__":
    working_dir = os.path.dirname(os.path.realpath(os.path.join(__file__, "..")))
    config_path = os.path.join(
        working_dir, "..", "risk_biased", "config", "learning_config.py"
    )
    config = Config.fromfile(config_path)
    if config.seed is not None:
        seed_everything(config.seed)
    ped_speed = 2
    is_torch = False

    fig, ax = plt.subplots(
        3, 4, sharex=True, sharey=True, tight_layout=True, subplot_kw={"aspect": 1}
    )

    scene_params = RandomSceneParams.from_config(config)
    scene_params.batch_size = 1
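
    # NOTE (explanation added for clarity): the loop below fills a 3x3 block of
    # panels, one per pedestrian heading. For ii in 0..8, (i, j) index the
    # subplot row/column and (vx, vy) the unnormalized heading: e.g. ii == 0
    # gives (i, j) = (2, 0) and (vx, vy) = (-1, -1) (bottom-left panel,
    # pedestrian heading down-left); ii == 4 gives (1, 1) and (0, 0) (a
    # standing pedestrian, kept safe from division by zero by the norm clamp).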
    for ii in range(9):
        test_scene = RandomScene(
            scene_params,
            is_torch=is_torch,
        )
        dist_cost = DistanceCostNumpy(DistanceCostParams.from_config(config))
        ttc_cost = TTCCostNumpy(TTCCostParams.from_config(config))

        # Grid of candidate pedestrian positions covering the whole scene.
        nx = 1000
        ny = 100
        x, y = np.meshgrid(
            np.linspace(-test_scene.ego_length, test_scene.road_length, nx),
            np.linspace(test_scene.bottom, test_scene.top, ny),
        )

        i = 2 - (int(ii >= 6) + int(ii >= 3))
        j = ii % 3
        vx = float(ii % 3 - 1)
        vy = float(ii >= 6) - float(ii <= 2)
        print(f"horizontal velocity {vx}")
        print(f"vertical velocity {vy}")
        # Normalize the heading to unit length (clamped at 1 to avoid dividing by zero).
        norm = np.maximum(np.sqrt(vx * vx + vy * vy), 1)
        vx = vx / norm * np.ones([nx * ny, 1])
        vy = vy / norm * np.ones([nx * ny, 1])
        v_ped = ped_speed * np.stack((vx, vy), -1)
        v_ego = np.array([[[test_scene.ego_ref_speed, 0]]])

        # One straight-line pedestrian trajectory per grid cell.
        p_init = np.stack((x, y), -1).reshape((nx * ny, 2))
        p_final = p_init + v_ped[:, 0, :] * test_scene.time_scene
        len_traj = 30
        ped_trajs = np.linspace(p_init, p_final, len_traj, axis=1)
        ego_traj = np.linspace(
            [[0, 0]],
            [test_scene.ego_ref_speed * test_scene.time_scene, 0],
            len_traj,
            axis=1,
        )
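
        # NOTE (shape walkthrough, added for clarity): ped_trajs has shape
        # (nx * ny, len_traj, 2), so the TTC cost below is evaluated against
        # the same ego trajectory for every pedestrian starting position in
        # the grid at once.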
        cost, _ = ttc_cost(ego_traj, ped_trajs, v_ego, v_ped)

        cost = cost.reshape(ny, nx)
        colorbar = ax[i][j].pcolormesh(x, y, cost, cmap="RdBu")
        plotter = ScenePlotter(test_scene, ax=ax[i][j])
        plotter.plot_road()

    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    fig.tight_layout()
    fig.colorbar(colorbar, ax=ax.ravel().tolist())
    # The loop only fills the first three columns; drop the unused fourth column of axes.
    for a in ax[:, -1]:
        a.remove()
    plt.show()