optimize dataloading to use cache, fix model token embedding sizes
Files changed:
- src/axolotl/utils/data.py +72 -14
- src/axolotl/utils/models.py +2 -0
src/axolotl/utils/data.py
```diff
@@ -31,13 +31,7 @@ from axolotl.prompters import (
 )
 
 
-def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
-    max_packed_sequence_len = (
-        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
-    )
-    max_packed_sequence_len = min(
-        max_packed_sequence_len, cfg.sequence_len
-    )  # make sure we don't accidentally set it larger than sequence_len
+def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
     ds_hash = str(
         md5(
             (
```
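This hunk renames `load_prepare_datasets` to `load_tokenized_prepared_datasets` and lifts the packing-length clamp out of it; the clamp reappears in the new `load_prepare_datasets` wrapper added in the third hunk (as sketched below). A standalone sketch of the clamp logic, with hypothetical values standing in for the cfg fields:

```python
# Minimal sketch of the length clamp above; the cfg values are made up.
cfg_sequence_len = 2048              # stands in for cfg.sequence_len
cfg_max_packed_sequence_len = 8192   # stands in for cfg.max_packed_sequence_len

max_packed_sequence_len = (
    cfg_max_packed_sequence_len if cfg_max_packed_sequence_len else cfg_sequence_len
)
# make sure we don't accidentally set it larger than sequence_len
max_packed_sequence_len = min(max_packed_sequence_len, cfg_sequence_len)
print(max_packed_sequence_len)  # 2048
```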
```diff
@@ -54,7 +48,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
 
     if any(prepared_ds_path.glob("*")):
-        logging.info(f"Loading prepared dataset from disk...")
+        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
```
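This hunk only makes the cache-hit log line name the directory being loaded. The surrounding check is the plain `datasets` save/load round-trip; a minimal sketch, assuming a hypothetical cache path:

```python
# Sketch of the disk-cache check above; prepared_ds_path is hypothetical.
from pathlib import Path

from datasets import Dataset, load_from_disk

prepared_ds_path = Path("last_run_prepared/abc123")  # hypothetical cache dir

if any(prepared_ds_path.glob("*")):
    # cache hit: reuse the tokenized dataset saved by a previous run
    dataset = load_from_disk(str(prepared_ds_path))
else:
    # cache miss: build the dataset and persist it for the next run
    dataset = Dataset.from_dict({"input_ids": [[1, 2, 3]]})
    dataset.save_to_disk(str(prepared_ds_path))
```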
```diff
@@ -153,14 +147,78 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
     dataset.save_to_disk(prepared_ds_path)
 
+    return dataset
+
+
+def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+
     if cfg.max_packed_sequence_len is not None:
-        constant_len_dataset = ConstantLengthDataset(
-            tokenizer,
-            [dataset],
-            seq_length=max_packed_sequence_len,
+        # see if we can go ahead and load the stacked dataset
+
+        ds_hash = str(
+            md5(
+                (
+                    str(cfg.sequence_len)
+                    + "@"
+                    + str(max_packed_sequence_len)
+                    + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+                ).encode("utf-8")
+            ).hexdigest()
+        )
+        prepared_ds_path = (
+            Path(cfg.dataset_prepared_path) / ds_hash
+            if cfg.dataset_prepared_path
+            else Path(default_dataset_prepared_path) / ds_hash
+        )
+
+        if any(prepared_ds_path.glob("*")):
+            logging.info(
+                f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
+            )
+            dataset = load_from_disk(str(prepared_ds_path))
+            logging.info("Prepared packed dataset loaded from disk...")
+        else:
+            dataset = load_tokenized_prepared_datasets(
+                tokenizer, cfg, default_dataset_prepared_path
+            )
+
+            constant_len_dataset = ConstantLengthDataset(
+                tokenizer,
+                [dataset],
+                seq_length=max_packed_sequence_len,
+            )
+            logging.info(
+                f"packing master dataset to len: {cfg.max_packed_sequence_len}"
+            )
+            dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+            if cfg.local_rank == 0:
+                logging.info(
+                    f"Saving packed prepared dataset to disk... {prepared_ds_path}"
+                )
+                dataset.save_to_disk(prepared_ds_path)
+    else:
+        dataset = load_tokenized_prepared_datasets(
+            tokenizer, cfg, default_dataset_prepared_path
         )
-        logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+    # filter out bad data
+    dataset = Dataset.from_list(
+        [
+            d
+            for d in dataset
+            if len(d["input_ids"]) > cfg.sequence_len
+            and len(d["input_ids"]) > 0
+            and len(d["input_ids"]) == len(d["attention_mask"])
+            and len(d["input_ids"]) == len(d["labels"])
+        ]
+    )
 
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         logging.info(
```
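This is the core of the change: the packed dataset now gets its own on-disk cache, keyed by an md5 over every setting that affects the packed output, so changing any of those settings automatically produces a fresh cache directory. A standalone sketch of the key derivation, with hypothetical values in place of the cfg fields:

```python
# Sketch of the cache-key scheme above; the values are hypothetical.
from hashlib import md5

sequence_len = 2048
max_packed_sequence_len = 2048
dataset_specs = ["data/alpaca.jsonl:alpaca", "data/teacher.jsonl:gpteacher"]

ds_hash = md5(
    (
        str(sequence_len)
        + "@"
        + str(max_packed_sequence_len)
        + "|".join(sorted(dataset_specs))
    ).encode("utf-8")
).hexdigest()
print(ds_hash)  # stable across runs with the same settings
```

Note that the `{d.path}:{d.type}` specs are sorted before joining, so reordering the datasets in the config does not invalidate the cache.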
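On a cache miss, the tokenized dataset is streamed through `ConstantLengthDataset` and materialized with `Dataset.from_list`, so the packing cost is paid once and the result can be cached with `save_to_disk`. A sketch of the materialization step, with a hypothetical generator standing in for `ConstantLengthDataset` (whose definition isn't shown in this diff):

```python
# Sketch of materializing packed examples into a datasets.Dataset.
# packed_examples is a hypothetical stand-in for iterating ConstantLengthDataset.
from datasets import Dataset

packed_examples = (
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]}
    for _ in range(2)
)

# from_list consumes the iterator once; the resulting Dataset is cacheable.
dataset = Dataset.from_list(list(packed_examples))
print(len(dataset))  # 2
```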
src/axolotl/utils/models.py
```diff
@@ -181,6 +181,8 @@ def load_model(
     for k, v in cfg.tokens.items():
         tokenizer.add_special_tokens({k: v})
 
+    model.resize_token_embeddings(len(tokenizer))
+
     if cfg.adapter and load_in_8bit and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
```
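This is the token-embedding fix from the commit title: `add_special_tokens` can grow the vocabulary, and the model's embedding matrix must be resized to `len(tokenizer)` before training, otherwise the new token ids index past the end of the embedding table. A minimal sketch using stock transformers APIs; the model name is only an example:

```python
# Sketch of resizing embeddings after adding special tokens.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model
model = AutoModelForCausalLM.from_pretrained("gpt2")

tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # vocab grows by one
model.resize_token_embeddings(len(tokenizer))

assert model.get_input_embeddings().num_embeddings == len(tokenizer)
```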