{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58ff91ca-ce92-43d0-ae8b-4e9e89e193f6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import set_seed, AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "from peft import get_peft_model, MultitaskPromptTuningConfig, TaskType, MultitaskPromptTuningInit\n",
    "\n",
    "set_seed(42)\n",
    "\n",
    "model_name = \"google/flan-t5-base\"\n",
    "\n",
    "peft_config = MultitaskPromptTuningConfig(\n",
    "    tokenizer_name_or_path=model_name,\n",
    "    num_tasks=2,\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    prompt_tuning_init=MultitaskPromptTuningInit.TEXT,\n",
    "    num_virtual_tokens=50,\n",
    "    num_transformer_submodules=1,\n",
    "    prompt_tuning_init_text=\"classify the following into either positive or negative, or entailment, neutral or contradiction:\",\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
    "model = get_peft_model(model, peft_config)\n",
    "\n",
    "model = model.cuda()\n",
    "\n",
    "\n",
    "def send_to_device(batch):\n",
    "    for i in batch:\n",
    "        batch[i] = batch[i].cuda()\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb112bc1-ffaf-49fa-a216-0d601ec304ee",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_sst2(split: str):\n",
    "    examples = load_dataset(\"sst2\")[split]\n",
    "    result_examples = []\n",
    "    for example in examples:\n",
    "        result_examples.append({})\n",
    "\n",
    "        result_examples[-1][\"input\"] = example[\"sentence\"].strip() + \"</s>\"\n",
    "        result_examples[-1][\"output\"] = (\n",
    "            f\"positive{tokenizer.eos_token}\" if example[\"label\"] == 1 else f\"negative{tokenizer.eos_token}\"\n",
    "        )\n",
    "        result_examples[-1][\"task_id\"] = 0\n",
    "\n",
    "    return result_examples\n",
    "\n",
    "\n",
    "def get_mnli(split: str):\n",
    "    examples = load_dataset(\"multi_nli\")[split]\n",
    "    result_examples = []\n",
    "    for example in examples:\n",
    "        result_examples.append({})\n",
    "\n",
    "        result_examples[-1][\"input\"] = example[\"premise\"].strip() + \" \" + example[\"hypothesis\"].strip() + \"</s>\"\n",
    "\n",
    "        if example[\"label\"] == 0:\n",
    "            result_examples[-1][\"output\"] = f\"entailment{tokenizer.eos_token}\"\n",
    "        elif example[\"label\"] == 1:\n",
    "            result_examples[-1][\"output\"] = f\"neutral{tokenizer.eos_token}\"\n",
    "        else:\n",
    "            result_examples[-1][\"output\"] = f\"contradiction{tokenizer.eos_token}\"\n",
    "\n",
    "        result_examples[-1][\"task_id\"] = 1\n",
    "\n",
    "    return result_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a16ec4-8fef-4ba9-95b6-a661eb51e50c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch\n",
    "\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "    def __init__(self, split: str, mode: str = \"source\") -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        if split == \"train\":\n",
    "            if mode == \"source\":\n",
    "                self.examples = get_sst2(split) + get_mnli(split)\n",
    "            elif mode == \"target\":\n",
    "                self.examples = get_sst2(split)\n",
    "        if split == \"val\":\n",
    "            self.examples = get_sst2(\"validation\")\n",
    "        if split == \"test\":\n",
    "            self.examples = get_sst2(\"validation\")\n",
    "\n",
    "    def __getitem__(self, index) -> dict:\n",
    "        return self.examples[index]\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self.examples)\n",
    "\n",
    "    def __getitem__(self, index) -> dict:\n",
    "        return self.examples[index]\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self.examples)\n",
    "\n",
    "\n",
    "def collate_fn(batch: dict) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "    input = [i[\"input\"] for i in batch]\n",
    "    input = tokenizer(input, add_special_tokens=False, return_tensors=\"pt\", padding=True)\n",
    "\n",
    "    output = [i[\"output\"] for i in batch]\n",
    "    output = tokenizer(output, add_special_tokens=False, return_tensors=\"pt\", padding=True).input_ids\n",
    "    output[output == tokenizer.pad_token_id] = -100\n",
    "\n",
    "    task_ids = [i[\"task_id\"] for i in batch]\n",
    "    task_ids = torch.tensor(task_ids)\n",
    "\n",
    "    return {\n",
    "        \"input_ids\": input.input_ids,\n",
    "        \"attention_mask\": input.attention_mask,\n",
    "        \"labels\": output,\n",
    "        \"task_ids\": task_ids,\n",
    "    }\n",
    "\n",
    "\n",
    "train = DataLoader(MyDataset(\"train\"), shuffle=True, batch_size=8, collate_fn=collate_fn)\n",
    "val = DataLoader(MyDataset(\"val\"), shuffle=False, batch_size=8, collate_fn=collate_fn)\n",
    "test = DataLoader(MyDataset(\"test\"), shuffle=False, batch_size=8, collate_fn=collate_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe0aec7b-f61e-4b00-a90e-c1201dc1f84c",
   "metadata": {},
   "source": [
    "## source training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cceecc94-f43a-4f62-8d45-926f2f02f36d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.optim.adamw import AdamW\n",
    "from transformers import get_cosine_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import f1_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eae5516b-73ab-44a8-a083-4e8de6127f30",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "POSITIVE_TOKEN_ID = tokenizer(\" positive\", add_special_tokens=False)[\"input_ids\"][0]\n",
    "NEGATIVE_TOKEN_ID = tokenizer(\" negative\", add_special_tokens=False)[\"input_ids\"][0]\n",
    "\n",
    "\n",
    "def classify(batch):\n",
    "    batch = send_to_device(batch)\n",
    "    # we pass labels here since we need to generate and peft doesn't support generation yet.\n",
    "    # No clue how to get around this\n",
    "    scores = model(**batch).logits\n",
    "    preds = []\n",
    "    for i in range(scores.shape[0]):\n",
    "        if scores[i, 0, POSITIVE_TOKEN_ID] > scores[i, 0, NEGATIVE_TOKEN_ID]:\n",
    "            preds.append(POSITIVE_TOKEN_ID)\n",
    "        else:\n",
    "            preds.append(NEGATIVE_TOKEN_ID)\n",
    "    return preds\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def evaluate(model, data):\n",
    "    loss = 0\n",
    "    preds = []\n",
    "    golds = []\n",
    "\n",
    "    for batch in tqdm(data):\n",
    "        batch = send_to_device(batch)\n",
    "        loss += model(**batch).loss\n",
    "        golds.extend(batch[\"labels\"][:, 0].tolist())\n",
    "        preds.extend(classify(batch))\n",
    "\n",
    "    return loss / len(val), f1_score(golds, preds, pos_label=POSITIVE_TOKEN_ID)\n",
    "\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=1e-4)\n",
    "scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train))\n",
    "\n",
    "n = 1000\n",
    "step = 0\n",
    "train_ = tqdm(train)\n",
    "\n",
    "val_loss, f1 = evaluate(model, val)\n",
    "print(\n",
    "    f\"\"\"\n",
    "before source training\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    ")\n",
    "\n",
    "for batch in train_:\n",
    "    if step % n == 0:\n",
    "        val_loss, f1 = evaluate(model, val)\n",
    "        print(\n",
    "            f\"\"\"\n",
    "step = {step}\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    "        )\n",
    "        model.save_pretrained(f\"checkpoints_source/{step}\")\n",
    "\n",
    "    step += 1\n",
    "    batch = send_to_device(batch)\n",
    "    loss = model(**batch).loss\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    train_.set_postfix(train_loss=loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74168ef3-66f3-41a7-a40b-7840b103fbf9",
   "metadata": {},
   "source": [
    "## target training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09fd456-163e-4dc1-b24d-f2d0d349036c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "train = DataLoader(MyDataset(\"train\", \"target\"), shuffle=True, batch_size=8, collate_fn=collate_fn)\n",
    "val = DataLoader(MyDataset(\"val\", \"target\"), shuffle=False, batch_size=8, collate_fn=collate_fn)\n",
    "test = DataLoader(MyDataset(\"test\", \"target\"), shuffle=False, batch_size=8, collate_fn=collate_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a539944-f16c-4c3f-bb4a-7b5d9a6042e2",
   "metadata": {},
   "source": [
    "#### create a fresh model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5520d904-aa6c-4654-9335-ed4e7d76cba2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "peft_config = MultitaskPromptTuningConfig(\n",
    "    tokenizer_name_or_path=model_name,\n",
    "    num_tasks=1,\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    prompt_tuning_init=MultitaskPromptTuningInit.EXACT_SOURCE_TASK,\n",
    "    prompt_tuning_init_state_dict_path=\"checkpoints_source/50000/adapter_model.bin\",\n",
    "    num_virtual_tokens=50,\n",
    "    num_transformer_submodules=1,\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
    "model = get_peft_model(model, peft_config)\n",
    "\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfa39c2d-d1c5-4ed4-90f8-26e8e324371c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "optimizer = AdamW(model.parameters(), lr=1e-4)\n",
    "scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train))\n",
    "\n",
    "n = 1000\n",
    "step = 0\n",
    "train_ = tqdm(train)\n",
    "\n",
    "val_loss, f1 = evaluate(model, val)\n",
    "print(\n",
    "    f\"\"\"\n",
    "before target training\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    ")\n",
    "\n",
    "for batch in train_:\n",
    "    if step % n == 0:\n",
    "        val_loss, f1 = evaluate(model, val)\n",
    "        print(\n",
    "            f\"\"\"\n",
    "step = {step}\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    "        )\n",
    "        model.save_pretrained(f\"checkpoints_target/{step}\")\n",
    "\n",
    "    step += 1\n",
    "    batch = send_to_device(batch)\n",
    "    loss = model(**batch).loss\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    train_.set_postfix(train_loss=loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a6eeda-1e09-49a6-8845-cd96c8573145",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# load last checkpoint for now\n",
    "from peft import set_peft_model_state_dict\n",
    "\n",
    "sd_6000 = torch.load(\"checkpoints_target/6000/adapter_model.bin\")\n",
    "set_peft_model_state_dict(model, sd_6000)\n",
    "\n",
    "# evaluate val\n",
    "val_loss, f1 = evaluate(model, val)\n",
    "print(\n",
    "    f\"\"\"\n",
    "final\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    ")\n",
    "\n",
    "# evaluate test\n",
    "test_loss, f1 = evaluate(model, test)\n",
    "print(\n",
    "    f\"\"\"\n",
    "final\n",
    "test loss = {test_loss}\n",
    "f1 = {f1}\"\"\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}