import pickle
from typing import Any, Callable, Dict, Optional

import torch.distributed.checkpoint.stateful
import torch.utils.data
import torchdata.stateful_dataloader

from ..logging import get_logger

logger = get_logger()


class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful):
    """A StatefulDataLoader whose iteration state can be checkpointed and restored per data-parallel rank."""

    def __init__(
        self,
        rank: int,
        dataset: torch.utils.data.IterableDataset,
        batch_size: int = 1,
        num_workers: int = 0,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)
        self._dp_rank = rank
        self._rank_id = f"dp_rank_{rank}"

    def state_dict(self) -> Dict[str, Any]:
        # Store state only under this dp rank's key to avoid replicating the same state across other dimensions
        return {self._rank_id: pickle.dumps(super().state_dict())}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # An empty state is valid (e.g. a fresh run with nothing to resume)
        if not state_dict:
            return

        if self._rank_id not in state_dict:
            logger.warning(f"DataLoader state is missing key {self._rank_id} for dp rank {self._dp_rank}")
            return

        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
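

# A minimal usage sketch (illustrative only; `build_dataset` and the `ckpt` dict are
# hypothetical stand-ins, not part of this module):
#
#     import torch.distributed as dist
#
#     dataloader = DPDataLoader(rank=dist.get_rank(), dataset=build_dataset(), batch_size=4)
#
#     # At checkpoint time, save the loader state alongside model/optimizer state:
#     ckpt["dataloader"] = dataloader.state_dict()
#
#     # On resume, restore it so iteration continues where it left off:
#     dataloader.load_state_dict(ckpt["dataloader"])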