{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71fbfca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM\n",
    "from peft import PeftModel, PeftConfig\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "text_column = \"Tweet text\"\n",
    "label_column = \"text_label\"\n",
    "batch_size = 8\n",
    "\n",
    "peft_model_id = \"smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM\"\n",
    "config = PeftConfig.from_pretrained(peft_model_id)"
   ]
  },
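  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3e1c9b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check on the adapter config before loading any weights.\n",
    "# These attributes are standard on a PEFT LoraConfig, though the exact set\n",
    "# can vary across peft versions.\n",
    "print(config.base_model_name_or_path)\n",
    "print(config.peft_type, config.task_type)"
   ]
  },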
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cc55820a",
   "metadata": {},
   "outputs": [],
   "source": [
    "peft_model_id = \"smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM\"\n",
    "max_memory = {0: \"6GIB\", 1: \"0GIB\", 2: \"0GIB\", 3: \"0GIB\", 4: \"0GIB\", \"cpu\": \"30GB\"}\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, device_map=\"auto\", max_memory=max_memory)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id, device_map=\"auto\", max_memory=max_memory)"
   ]
  },
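  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8d4f2e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: see how accelerate sharded the model given `max_memory`.\n",
    "# `hf_device_map` is populated when `device_map=\"auto\"` is used, and\n",
    "# `print_trainable_parameters` is a PeftModel helper; with LoRA it should\n",
    "# report only a small fraction of parameters as trainable.\n",
    "print(model.hf_device_map)\n",
    "model.print_trainable_parameters()"
   ]
  },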
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a3648b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
    "\n",
    "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
    "print(classes)\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "print(dataset)\n",
    "dataset[\"train\"][0]"
   ]
  },
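  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2f7a1d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label distribution check: in RAFT only the 50-example train split carries\n",
    "# gold labels; the test split is marked \"Unlabeled\" (Label id 0), which is\n",
    "# why evaluation below runs on the train split.\n",
    "from collections import Counter\n",
    "\n",
    "print(Counter(dataset[\"train\"][\"text_label\"]))\n",
    "print(dataset[\"test\"].num_rows, \"test rows, first labels:\", dataset[\"test\"][\"text_label\"][:5])"
   ]
  },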
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe12d4d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    inputs = examples[text_column]\n",
    "    targets = examples[label_column]\n",
    "    model_inputs = tokenizer(inputs, truncation=True)\n",
    "    labels = tokenizer(\n",
    "        targets, max_length=target_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n",
    "    )\n",
    "    labels = labels[\"input_ids\"]\n",
    "    labels[labels == tokenizer.pad_token_id] = -100\n",
    "    model_inputs[\"labels\"] = labels\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=True,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"train\"]\n",
    "test_dataset = processed_datasets[\"test\"]\n",
    "\n",
    "\n",
    "def collate_fn(examples):\n",
    "    return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
    "\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)\n",
    "test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)"
   ]
  },
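  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b3e8a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape check on one collated batch: `tokenizer.pad` pads input_ids and\n",
    "# attention_mask to the longest sequence in the batch, while labels were\n",
    "# already padded to `target_max_length` in `preprocess_function`.\n",
    "sample_batch = next(iter(train_dataloader))\n",
    "print({k: v.shape for k, v in sample_batch.items()})"
   ]
  },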
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b33be5e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "@NYTsupport i have complained a dozen times & yet my papers are still thrown FAR from my door. Why is this so hard to resolve?\n",
      "{'input_ids': tensor([[25335,  1499,     3,    10,  3320, 12056,   382, 20390,     3,    23,\n",
      "            43, 25932,     3,     9,  9611,   648,     3,   184,  4624,   117,\n",
      "           780,    82,  5778,    33,   341,     3, 12618,   377,  4280,    45,\n",
      "            82,  1365,     5,  1615,    19,    48,    78,   614,    12,  7785,\n",
      "            58, 16229,     3,    10,     3,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[    0, 10394,     1]], device='cuda:0')\n",
      "['complaint']\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "i = 15\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(input_ids=inputs[\"input_ids\"].to(\"cuda\"), max_new_tokens=10)\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b6d6cd5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                    | 0/7 [00:00<?, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:10<00:00,  1.48s/it]\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "eval_preds = []\n",
    "for _, batch in enumerate(tqdm(eval_dataloader)):\n",
    "    batch = {k: v.to(\"cuda\") for k, v in batch.items() if k != \"labels\"}\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(**batch, max_new_tokens=10)\n",
    "    preds = outputs.detach().cpu().numpy()\n",
    "    eval_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "61264abe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy=100.0\n",
      "eval_preds[:10]=['no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint', 'no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint']\n",
      "dataset['train'][label_column][:10]=['no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint', 'no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint']\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "for pred, true in zip(eval_preds, dataset[\"train\"][label_column]):\n",
    "    if pred.strip() == true.strip():\n",
    "        correct += 1\n",
    "    total += 1\n",
    "accuracy = correct / total * 100\n",
    "print(f\"{accuracy=}\")\n",
    "print(f\"{eval_preds[:10]=}\")\n",
    "print(f\"{dataset['train'][label_column][:10]=}\")"
   ]
  },
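  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1c5d7b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-class view of the same predictions, since overall accuracy can hide\n",
    "# systematic errors on one class. Keys are (predicted, gold) pairs.\n",
    "from collections import Counter\n",
    "\n",
    "pairs = Counter(zip((p.strip() for p in eval_preds), (t.strip() for t in dataset[\"train\"][label_column])))\n",
    "print(pairs)"
   ]
  },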
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a70802a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "test_preds = []\n",
    "\n",
    "for _, batch in enumerate(tqdm(test_dataloader)):\n",
    "    batch = {k: v for k, v in batch.items() if k != \"labels\"}\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(**batch, max_new_tokens=10)\n",
    "    preds = outputs.detach().cpu().numpy()\n",
    "    test_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))\n",
    "    if len(test_preds) > 100:\n",
    "        break\n",
    "test_preds"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5 (v3.10.5:f377153967, Jun  6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}