Spaces:
Running
Running
import pickle | |
from typing import Any, Dict | |
import torch.distributed.checkpoint.stateful | |
import torchdata.stateful_dataloader | |
from ..logging import get_logger | |
# Module-level logger obtained from this package's `get_logger` helper (imported from
# `..logging`); DPDataLoader.load_state_dict uses it to warn when a rank's key is missing.
logger = get_logger()
class DPDataLoader(
    torchdata.stateful_dataloader.StatefulDataLoader,
    torch.distributed.checkpoint.stateful.Stateful,
):
    """A ``StatefulDataLoader`` whose checkpoint state is namespaced by data-parallel rank.

    ``state_dict`` / ``load_state_dict`` follow the ``torch.distributed.checkpoint``
    ``Stateful`` protocol. The parent loader's state is pickled and stored under a
    single key, ``"dp_rank_<rank>"``, so each data-parallel rank saves and restores
    only its own entry.
    """

    def __init__(
        self,
        rank: int,
        dataset: torch.utils.data.IterableDataset,
        batch_size: int = 1,
        num_workers: int = 0,
        collate_fn=None,
    ) -> None:
        """Build the loader.

        Args:
            rank: Data-parallel rank of this process. It is used only to build the
                checkpoint key and in log messages; it is not passed to the parent
                loader and does not shard ``dataset`` here.
            dataset: Iterable dataset handed to ``StatefulDataLoader``.
            batch_size: Forwarded to ``StatefulDataLoader``.
            num_workers: Forwarded to ``StatefulDataLoader``.
            collate_fn: Forwarded to ``StatefulDataLoader``.
        """
        super().__init__(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=collate_fn,
        )
        self._rank_id = f"dp_rank_{rank}"
        self._dp_rank = rank

    def state_dict(self) -> Dict[str, Any]:
        """Return this rank's loader state as ``{"dp_rank_<rank>": <pickled bytes>}``."""
        # The key carries only the DP rank, so the loader state is not replicated
        # across the other parallel dimensions.
        loader_state = super().state_dict()
        return {self._rank_id: pickle.dumps(loader_state)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore loader state previously produced by :meth:`state_dict`.

        An empty (falsy) ``state_dict`` is accepted and leaves the loader untouched.
        If the mapping lacks this rank's key, a warning is logged and nothing is
        restored. Otherwise the stored bytes are unpickled and handed to
        ``StatefulDataLoader.load_state_dict``.
        """
        if not state_dict:
            # Having no saved loader state is a valid situation, not an error.
            return
        rank_key = self._rank_id
        if rank_key in state_dict:
            # NOTE(review): pickle.loads can execute arbitrary code embedded in the
            # blob; only restore from checkpoints that come from a trusted source.
            super().load_state_dict(pickle.loads(state_dict[rank_key]))
        else:
            logger.warning(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}")