jbilcke-hf's picture
jbilcke-hf HF Staff
upgrading finetrainers (and losing my extra code + improvements)
80ebcb3
raw
history blame
2.53 kB
from typing import Any, Dict, List, Tuple
import torch
class ResolutionSampler:
    """Group consumed samples into fixed-size batches with matching tensor sizes.

    Every call to :meth:`consume` supplies one sample, given as one or more
    dictionaries. Exactly one of those dictionaries must hold the "leader"
    tensor: the value stored under one of the keys of ``dim_keys``. The sizes
    of the leader tensor along the dimensions listed in ``dim_keys`` form the
    bucket key. Once a bucket has collected ``batch_size`` samples it is moved
    to the ready list and can be retrieved with :meth:`get_batch`.

    Args:
        batch_size: Number of samples a bucket must collect before it is
            considered ready.
        dim_keys: Maps each candidate leader key to the tensor dimensions
            whose sizes identify a bucket. Required despite the ``None``
            default.

    Raises:
        AssertionError: If ``dim_keys`` is ``None``.
    """

    def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`. AssertionError is kept so callers see
        # the same exception type as before.
        if dim_keys is None:
            raise AssertionError("dim_keys must be provided")

        self.batch_size = batch_size
        self.dim_keys = dim_keys

        # Leader key picked on the first successful `consume` call.
        self._chosen_leader_key = None
        # Partially filled buckets, keyed by the leader tensor's sizes.
        self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Tuple[Dict[Any, Any], ...]]] = {}
        # Buckets that reached `batch_size` and are waiting to be fetched.
        self._satisfied_buckets: List[List[Tuple[Dict[Any, Any], ...]]] = []

    def consume(self, *dict_items: Dict[Any, Any]) -> None:
        """Add one sample (one or more dictionaries) to its size bucket.

        On the first call the leader key is determined from ``dict_items``;
        later calls reuse it.

        Raises:
            ValueError: If no leader key can be determined, if more than one
                leader key is present on the first call, if the leader value
                is not a tensor on the first call, or if the chosen leader
                key is missing from ``dict_items``.
        """
        if self._chosen_leader_key is None:
            self._determine_leader_item(*dict_items)
        self._update_buckets(*dict_items)

    def get_batch(self) -> List[Tuple[Dict[Any, Any], ...]]:
        """Pop the most recently completed bucket and return it transposed.

        The bucket is stored as one tuple of dictionaries per sample. The
        result has one tuple per dictionary position, each containing that
        position's dictionary from every sample in the batch.

        Raises:
            IndexError: If no bucket is ready (check :attr:`is_ready` first).
        """
        ready_bucket = self._satisfied_buckets.pop(-1)
        return list(zip(*ready_bucket))

    @property
    def is_ready(self) -> bool:
        """Whether at least one completed bucket is available."""
        return bool(self._satisfied_buckets)

    def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None:
        """Pick the single key of ``dim_keys`` that appears in ``dict_items``.

        The key is stored on the instance only after every check has passed,
        so a rejected input leaves the sampler without a chosen leader key.

        Raises:
            ValueError: If a matching value is not a tensor, if more than one
                leader key is found, or if none is found.
        """
        candidate_key = None
        num_observed = 0
        for dict_item in dict_items:
            for key in self.dim_keys:
                if key not in dict_item:
                    continue
                if not torch.is_tensor(dict_item[key]):
                    raise ValueError(f"Leader key {key} must be a tensor")
                candidate_key = key
                num_observed += 1

        if num_observed > 1:
            raise ValueError(
                f"Only one leader key is allowed in provided list of data dictionaries. Found {num_observed} leader keys"
            )
        if candidate_key is None:
            raise ValueError("No leader key found in provided list of data dictionaries")

        # Commit only now that the input is known to be valid.
        self._chosen_leader_key = candidate_key

    def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None:
        """File ``dict_items`` under the bucket matching its leader tensor sizes.

        When the bucket reaches ``batch_size`` samples it is moved from the
        pending buckets to the ready list.

        Raises:
            ValueError: If none of ``dict_items`` contains the leader key.
        """
        leader_key = self._chosen_leader_key

        # Use the first dictionary that carries the leader key.
        for dict_item in dict_items:
            if leader_key in dict_item:
                leader_value = dict_item[leader_key]
                break
        else:
            raise ValueError(f"Leader key {self._chosen_leader_key} not found in provided list of data dictionaries")

        dims = tuple(leader_value.size(dim) for dim in self.dim_keys[leader_key])

        bucket = self._unsatisfied_buckets.setdefault(dims, [])
        bucket.append(dict_items)
        if len(bucket) == self.batch_size:
            self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims))