{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "71fbfca2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
      "================================================================================\n",
      "CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
      "CUDA SETUP: Detected CUDA version 117\n",
      "CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "from peft import PeftModel, PeftConfig\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"bigscience/bloomz-7b1\"\n",
    "tokenizer_name_or_path = \"bigscience/bloomz-7b1\"\n",
    "dataset_name = \"twitter_complaints\"\n",
    "text_column = \"Tweet text\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 64\n",
    "lr = 1e-3\n",
    "num_epochs = 50\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a3648b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
    "\n",
    "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
    "print(classes)\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "print(dataset)\n",
    "dataset[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fe12d4d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10cabeec92ab428f9a660ebaecbaf865",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8a344e989ab34c71b230acee68b477e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/4 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "print(target_max_length)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    targets = [str(x) for x in examples[label_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    labels = tokenizer(targets, add_special_tokens=False)  # don't add bos token because we concatenate with inputs\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
    "        # print(i, sample_input_ids, label_input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
    "        labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "        labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2795b9d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    test_preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "eval_dataset = processed_datasets[\"train\"]\n",
    "test_dataset = processed_datasets[\"test\"]\n",
    "\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
    "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
    "print(next(iter(eval_dataloader)))\n",
    "print(next(iter(test_dataloader)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42b14a11",
   "metadata": {},
   "source": [
    "You can load model from hub or local\n",
    "\n",
    "- Load model from Hugging Face Hub, you can change to your own model id\n",
    "```python\n",
    "peft_model_id = \"username/twitter_complaints_bigscience_bloomz-7b1_LORA_CAUSAL_LM\"\n",
    "```\n",
    "- Or load model form local\n",
    "```python\n",
    "peft_model_id = \"twitter_complaints_bigscience_bloomz-7b1_LORA_CAUSAL_LM\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9caac014",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/sourab/pet/src/peft/tuners/lora.py:143: UserWarning: fan_in_fan_out is set to True but the target module is not a Conv1D. Setting fan_in_fan_out to False.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bc38030106a14173a1363eb1ee388eda",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/15.8M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "max_memory = {0: \"1GIB\", 1: \"1GIB\", 2: \"2GIB\", 3: \"10GIB\", \"cpu\": \"30GB\"}\n",
    "peft_model_id = \"smangrul/twitter_complaints_bigscience_bloomz-7b1_LORA_CAUSAL_LM\"\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map=\"auto\", max_memory=max_memory)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id, device_map=\"auto\", max_memory=max_memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "6fac10b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2a08ee6d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'base_model.model.transformer.word_embeddings': 3,\n",
       " 'base_model.model.lm_head': 3,\n",
       " 'base_model.model.transformer.word_embeddings_layernorm': 3,\n",
       " 'base_model.model.transformer.h.0': 3,\n",
       " 'base_model.model.transformer.h.1': 3,\n",
       " 'base_model.model.transformer.h.2': 3,\n",
       " 'base_model.model.transformer.h.3': 3,\n",
       " 'base_model.model.transformer.h.4': 3,\n",
       " 'base_model.model.transformer.h.5': 3,\n",
       " 'base_model.model.transformer.h.6': 3,\n",
       " 'base_model.model.transformer.h.7': 3,\n",
       " 'base_model.model.transformer.h.8': 'cpu',\n",
       " 'base_model.model.transformer.h.9': 'cpu',\n",
       " 'base_model.model.transformer.h.10': 'cpu',\n",
       " 'base_model.model.transformer.h.11': 'cpu',\n",
       " 'base_model.model.transformer.h.12': 'cpu',\n",
       " 'base_model.model.transformer.h.13': 'cpu',\n",
       " 'base_model.model.transformer.h.14': 'cpu',\n",
       " 'base_model.model.transformer.h.15': 'cpu',\n",
       " 'base_model.model.transformer.h.16': 'cpu',\n",
       " 'base_model.model.transformer.h.17': 'cpu',\n",
       " 'base_model.model.transformer.h.18': 'cpu',\n",
       " 'base_model.model.transformer.h.19': 'cpu',\n",
       " 'base_model.model.transformer.h.20': 'cpu',\n",
       " 'base_model.model.transformer.h.21': 'cpu',\n",
       " 'base_model.model.transformer.h.22': 'cpu',\n",
       " 'base_model.model.transformer.h.23': 'cpu',\n",
       " 'base_model.model.transformer.h.24': 'cpu',\n",
       " 'base_model.model.transformer.h.25': 'cpu',\n",
       " 'base_model.model.transformer.h.26': 'cpu',\n",
       " 'base_model.model.transformer.h.27': 'cpu',\n",
       " 'base_model.model.transformer.h.28': 'cpu',\n",
       " 'base_model.model.transformer.h.29': 'cpu',\n",
       " 'base_model.model.transformer.ln_f': 'cpu'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.hf_device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "b33be5e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "@HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again.\n",
      "{'input_ids': tensor([[227985,   5484,    915,   2566, 216744,     38,   1316,     54,  42705,\n",
      "          32465,  52166,   9440,   1809,   3784,  88483,   9411,    368,  84342,\n",
      "           4451,     17,    473,   2152,  11705,  82406,    267,  51591,   5734,\n",
      "             17,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[227985,   5484,    915,   2566, 216744,     38,   1316,     54,  42705,\n",
      "          32465,  52166,   9440,   1809,   3784,  88483,   9411,    368,  84342,\n",
      "           4451,     17,    473,   2152,  11705,  82406,    267,  51591,   5734,\n",
      "             17,  77658,    915,    210,  16449,   5952,      3,      3,      3,\n",
      "              3,      3,      3,      3,      3]])\n",
      "['Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label : complaint']\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "i = 89\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b6d6cd5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:42<00:00, 14.70s/it]\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "eval_preds = []\n",
    "for _, batch in enumerate(tqdm(eval_dataloader)):\n",
    "    batch = {k: v for k, v in batch.items() if k != \"labels\"}\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(**batch, max_new_tokens=10)\n",
    "    preds = outputs[:, max_length:].detach().cpu().numpy()\n",
    "    eval_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "61264abe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy=100.0\n",
      "eval_preds[:10]=['no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint', 'no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint']\n",
      "dataset['train'][label_column][:10]=['no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint', 'no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint']\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "for pred, true in zip(eval_preds, dataset[\"train\"][label_column]):\n",
    "    if pred.strip() == true.strip():\n",
    "        correct += 1\n",
    "    total += 1\n",
    "accuracy = correct / total * 100\n",
    "print(f\"{accuracy=}\")\n",
    "print(f\"{eval_preds[:10]=}\")\n",
    "print(f\"{dataset['train'][label_column][:10]=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a70802a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "test_preds = []\n",
    "\n",
    "for _, batch in enumerate(tqdm(test_dataloader)):\n",
    "    batch = {k: v for k, v in batch.items() if k != \"labels\"}\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(**batch, max_new_tokens=10)\n",
    "    preds = outputs[:, max_length:].detach().cpu().numpy()\n",
    "    test_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))\n",
    "    if len(test_preds) > 100:\n",
    "        break\n",
    "test_preds"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}