{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f93b7d1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T09:49:56.334329Z",
     "start_time": "2023-05-30T09:49:54.494916Z"
    }
   },
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtransformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, GenerationConfig\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpeft\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1055\u001b[0m, in \u001b[0;36m_handle_fromlist\u001b[0;34m(module, fromlist, import_, recursive)\u001b[0m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/transformers/utils/import_utils.py:1076\u001b[0m, in \u001b[0;36m_LazyModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1074\u001b[0m     value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_module(name)\n\u001b[1;32m   1075\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_class_to_module\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m-> 1076\u001b[0m     module \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_module\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_class_to_module\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1077\u001b[0m     value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(module, name)\n\u001b[1;32m   1078\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/transformers/utils/import_utils.py:1086\u001b[0m, in \u001b[0;36m_LazyModule._get_module\u001b[0;34m(self, module_name)\u001b[0m\n\u001b[1;32m   1084\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_get_module\u001b[39m(\u001b[38;5;28mself\u001b[39m, module_name: \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m   1085\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1086\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mimportlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimport_module\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m.\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmodule_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1087\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m   1088\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m   1089\u001b[0m             \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to import \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodule_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m because of the following error (look up to see its\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1090\u001b[0m             \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m traceback):\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1091\u001b[0m         ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/importlib/__init__.py:127\u001b[0m, in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m    125\u001b[0m             \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[1;32m    126\u001b[0m         level \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m--> 127\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_bootstrap\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_gcd_import\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m[\u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpackage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/transformers/training_args_seq2seq.py:21\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Optional, Union\n\u001b[1;32m     20\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mgeneration\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfiguration_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m GenerationConfig\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtraining_args\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TrainingArguments\n\u001b[1;32m     22\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m add_start_docstrings\n\u001b[1;32m     25\u001b[0m logger \u001b[38;5;241m=\u001b[39m logging\u001b[38;5;241m.\u001b[39mgetLogger(\u001b[38;5;18m__name__\u001b[39m)\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/transformers/training_args.py:29\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Any, Dict, List, Optional, Union\n\u001b[1;32m     27\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpackaging\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m version\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdebug_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DebugOption\n\u001b[1;32m     30\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtrainer_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m     31\u001b[0m     EvaluationStrategy,\n\u001b[1;32m     32\u001b[0m     FSDPOption,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     36\u001b[0m     ShardedDDPOption,\n\u001b[1;32m     37\u001b[0m )\n\u001b[1;32m     38\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m     39\u001b[0m     ExplicitEnum,\n\u001b[1;32m     40\u001b[0m     cached_property,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     53\u001b[0m     requires_backends,\n\u001b[1;32m     54\u001b[0m )\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/transformers/debug_utils.py:21\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ExplicitEnum, is_torch_available, logging\n\u001b[1;32m     20\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_torch_available():\n\u001b[0;32m---> 21\u001b[0m     \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m     24\u001b[0m logger \u001b[38;5;241m=\u001b[39m logging\u001b[38;5;241m.\u001b[39mget_logger(\u001b[38;5;18m__name__\u001b[39m)\n\u001b[1;32m     27\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mDebugUnderflowOverflow\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/torch/__init__.py:1465\u001b[0m\n\u001b[1;32m   1463\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m library\n\u001b[1;32m   1464\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m TYPE_CHECKING:\n\u001b[0;32m-> 1465\u001b[0m     \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _meta_registrations\n\u001b[1;32m   1467\u001b[0m \u001b[38;5;66;03m# Enable CUDA Sanitizer\u001b[39;00m\n\u001b[1;32m   1468\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTORCH_CUDA_SANITIZER\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m os\u001b[38;5;241m.\u001b[39menviron:\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/torch/_meta_registrations.py:7\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_prims_common\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mutils\u001b[39;00m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Tensor\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_decomp\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _add_op_to_registry, global_decomposition_table, meta_table\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_ops\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpOverload\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_prims\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/torch/_decomp/__init__.py:169\u001b[0m\n\u001b[1;32m    165\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m decompositions\n\u001b[1;32m    168\u001b[0m \u001b[38;5;66;03m# populate the table\u001b[39;00m\n\u001b[0;32m--> 169\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_decomp\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdecompositions\u001b[39;00m\n\u001b[1;32m    170\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_refs\u001b[39;00m\n\u001b[1;32m    172\u001b[0m \u001b[38;5;66;03m# This list was copied from torch/_inductor/decomposition.py\u001b[39;00m\n\u001b[1;32m    173\u001b[0m \u001b[38;5;66;03m# excluding decompositions that results in prim ops\u001b[39;00m\n\u001b[1;32m    174\u001b[0m \u001b[38;5;66;03m# Resulting opset of decomposition is core aten ops\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/torch/_decomp/decompositions.py:10\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Callable, cast, Iterable, List, Optional, Tuple, Union\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_prims\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mprims\u001b[39;00m\n\u001b[1;32m     11\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_prims_common\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mutils\u001b[39;00m\n\u001b[1;32m     12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctional\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mF\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/torch/_prims/__init__.py:33\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_prims_common\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m     18\u001b[0m     check,\n\u001b[1;32m     19\u001b[0m     Dim,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     30\u001b[0m     type_to_dtype,\n\u001b[1;32m     31\u001b[0m )\n\u001b[1;32m     32\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_prims_common\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mwrappers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m backwards_not_supported\n\u001b[0;32m---> 33\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_subclasses\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfake_tensor\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FakeTensor, FakeTensorMode\n\u001b[1;32m     34\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01moverrides\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m handle_torch_function, has_torch_function\n\u001b[1;32m     35\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_pytree\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m tree_flatten, tree_map, tree_unflatten\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/torch/_subclasses/__init__.py:3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_subclasses\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfake_tensor\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m      4\u001b[0m     DynamicOutputShapeException,\n\u001b[1;32m      5\u001b[0m     FakeTensor,\n\u001b[1;32m      6\u001b[0m     FakeTensorMode,\n\u001b[1;32m      7\u001b[0m     UnsupportedFakeTensorException,\n\u001b[1;32m      8\u001b[0m )\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_subclasses\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfake_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CrossRefFakeMode\n\u001b[1;32m     12\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m     13\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFakeTensor\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     14\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFakeTensorMode\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     17\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCrossRefFakeMode\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     18\u001b[0m ]\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py:13\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mweakref\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ReferenceType\n\u001b[1;32m     12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_guards\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Source\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_ops\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpOverload\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_prims_common\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m     16\u001b[0m     elementwise_dtypes,\n\u001b[1;32m     17\u001b[0m     ELEMENTWISE_TYPE_PROMOTION_KIND,\n\u001b[1;32m     18\u001b[0m     is_float_dtype,\n\u001b[1;32m     19\u001b[0m     is_integer_dtype,\n\u001b[1;32m     20\u001b[0m )\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/torch/_guards.py:14\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[38;5;66;03m# TODO(voz): Stolen pattern, not sure why this is the case,\u001b[39;00m\n\u001b[1;32m     12\u001b[0m \u001b[38;5;66;03m# but mypy complains.\u001b[39;00m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 14\u001b[0m     \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m  \u001b[38;5;66;03m# type: ignore[import]\u001b[39;00m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m:\n\u001b[1;32m     16\u001b[0m     log\u001b[38;5;241m.\u001b[39mwarning(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo sympy found\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/__init__.py:74\u001b[0m\n\u001b[1;32m     67\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlogic\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (to_cnf, to_dnf, to_nnf, And, Or, Not, Xor, Nand, Nor,\n\u001b[1;32m     68\u001b[0m         Implies, Equivalent, ITE, POSform, SOPform, simplify_logic, bool_map,\n\u001b[1;32m     69\u001b[0m         true, false, satisfiable)\n\u001b[1;32m     71\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01massumptions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (AppliedPredicate, Predicate, AssumptionsContext,\n\u001b[1;32m     72\u001b[0m         assuming, Q, ask, register_handler, remove_handler, refine)\n\u001b[0;32m---> 74\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolys\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (Poly, PurePoly, poly_from_expr, parallel_poly_from_expr,\n\u001b[1;32m     75\u001b[0m         degree, total_degree, degree_list, LC, LM, LT, pdiv, prem, pquo,\n\u001b[1;32m     76\u001b[0m         pexquo, div, rem, quo, exquo, half_gcdex, gcdex, invert,\n\u001b[1;32m     77\u001b[0m         subresultants, resultant, discriminant, cofactors, gcd_list, gcd,\n\u001b[1;32m     78\u001b[0m         lcm_list, lcm, terms_gcd, trunc, monic, content, primitive, compose,\n\u001b[1;32m     79\u001b[0m         decompose, sturm, gff_list, gff, sqf_norm, sqf_part, sqf_list, sqf,\n\u001b[1;32m     80\u001b[0m         factor_list, factor, intervals, refine_root, count_roots, real_roots,\n\u001b[1;32m     81\u001b[0m         nroots, ground_roots, nth_power_roots_poly, cancel, reduced, groebner,\n\u001b[1;32m     82\u001b[0m         is_zero_dimensional, GroebnerBasis, poly, symmetrize, horner,\n\u001b[1;32m     83\u001b[0m         interpolate, rational_interpolate, viete, together,\n\u001b[1;32m     84\u001b[0m         BasePolynomialError, ExactQuotientFailed, PolynomialDivisionFailed,\n\u001b[1;32m     85\u001b[0m         OperationNotSupported, HeuristicGCDFailed, HomomorphismFailed,\n\u001b[1;32m     86\u001b[0m         IsomorphismFailed, ExtraneousFactors, EvaluationFailed,\n\u001b[1;32m     87\u001b[0m         RefinementFailed, CoercionFailed, NotInvertible, NotReversible,\n\u001b[1;32m     88\u001b[0m         NotAlgebraic, DomainError, PolynomialError, UnificationFailed,\n\u001b[1;32m     89\u001b[0m         GeneratorsError, GeneratorsNeeded, ComputationFailed,\n\u001b[1;32m     90\u001b[0m         UnivariatePolynomialError, MultivariatePolynomialError,\n\u001b[1;32m     91\u001b[0m         PolificationFailed, OptionError, FlagError, minpoly,\n\u001b[1;32m     92\u001b[0m         minimal_polynomial, primitive_element, field_isomorphism,\n\u001b[1;32m     93\u001b[0m         to_number_field, isolate, round_two, prime_decomp, prime_valuation,\n\u001b[1;32m     94\u001b[0m         galois_group, itermonomials, Monomial, lex, grlex,\n\u001b[1;32m     95\u001b[0m         grevlex, ilex, igrlex, igrevlex, CRootOf, rootof, RootOf,\n\u001b[1;32m     96\u001b[0m         ComplexRootOf, RootSum, roots, Domain, FiniteField, IntegerRing,\n\u001b[1;32m     97\u001b[0m         RationalField, RealField, ComplexField, PythonFiniteField,\n\u001b[1;32m     98\u001b[0m         GMPYFiniteField, PythonIntegerRing, GMPYIntegerRing, PythonRational,\n\u001b[1;32m     99\u001b[0m         GMPYRationalField, AlgebraicField, PolynomialRing, 
FractionField,\n\u001b[1;32m    100\u001b[0m         ExpressionDomain, FF_python, FF_gmpy, ZZ_python, ZZ_gmpy, QQ_python,\n\u001b[1;32m    101\u001b[0m         QQ_gmpy, GF, FF, ZZ, QQ, ZZ_I, QQ_I, RR, CC, EX, EXRAW,\n\u001b[1;32m    102\u001b[0m         construct_domain, swinnerton_dyer_poly, cyclotomic_poly,\n\u001b[1;32m    103\u001b[0m         symmetric_poly, random_poly, interpolating_poly, jacobi_poly,\n\u001b[1;32m    104\u001b[0m         chebyshevt_poly, chebyshevu_poly, hermite_poly, hermite_prob_poly,\n\u001b[1;32m    105\u001b[0m         legendre_poly, laguerre_poly, apart, apart_list, assemble_partfrac_list,\n\u001b[1;32m    106\u001b[0m         Options, ring, xring, vring, sring, field, xfield, vfield, sfield)\n\u001b[1;32m    108\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mseries\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (Order, O, limit, Limit, gruntz, series, approximants,\n\u001b[1;32m    109\u001b[0m         residue, EmptySequence, SeqPer, SeqFormula, sequence, SeqAdd, SeqMul,\n\u001b[1;32m    110\u001b[0m         fourier_series, fps, difference_delta, limit_seq)\n\u001b[1;32m    112\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (factorial, factorial2, rf, ff, binomial,\n\u001b[1;32m    113\u001b[0m         RisingFactorial, FallingFactorial, subfactorial, carmichael,\n\u001b[1;32m    114\u001b[0m         fibonacci, lucas, motzkin, tribonacci, harmonic, bernoulli, bell, euler,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    133\u001b[0m         Znm, elliptic_k, elliptic_f, elliptic_e, elliptic_pi, beta, mathieus,\n\u001b[1;32m    134\u001b[0m         mathieuc, mathieusprime, mathieucprime, riemann_xi, betainc, betainc_regularized)\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/polys/__init__.py:78\u001b[0m\n\u001b[1;32m      3\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m      4\u001b[0m     \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPoly\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPurePoly\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpoly_from_expr\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparallel_poly_from_expr\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdegree\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m      5\u001b[0m     \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtotal_degree\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdegree_list\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLM\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLT\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpdiv\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mprem\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpquo\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     65\u001b[0m     \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfield\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mxfield\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvfield\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msfield\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m     66\u001b[0m ]\n\u001b[1;32m     68\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolytools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (Poly, PurePoly, poly_from_expr,\n\u001b[1;32m     69\u001b[0m         parallel_poly_from_expr, degree, total_degree, degree_list, LC, LM,\n\u001b[1;32m     70\u001b[0m         LT, pdiv, prem, pquo, pexquo, div, rem, quo, exquo, half_gcdex, gcdex,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     75\u001b[0m         count_roots, real_roots, nroots, ground_roots, nth_power_roots_poly,\n\u001b[1;32m     76\u001b[0m         cancel, reduced, groebner, is_zero_dimensional, GroebnerBasis, poly)\n\u001b[0;32m---> 78\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolyfuncs\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (symmetrize, horner, interpolate,\n\u001b[1;32m     79\u001b[0m         rational_interpolate, viete)\n\u001b[1;32m     81\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrationaltools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m together\n\u001b[1;32m     83\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolyerrors\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (BasePolynomialError, ExactQuotientFailed,\n\u001b[1;32m     84\u001b[0m         PolynomialDivisionFailed, OperationNotSupported, HeuristicGCDFailed,\n\u001b[1;32m     85\u001b[0m         HomomorphismFailed, IsomorphismFailed, ExtraneousFactors,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     90\u001b[0m         MultivariatePolynomialError, PolificationFailed, 
OptionError,\n\u001b[1;32m     91\u001b[0m         FlagError)\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/polys/polyfuncs.py:10\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolys\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolyoptions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m allowed_flags, build_options\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolys\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolytools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m poly_from_expr, Poly\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolys\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mspecialpolys\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m     11\u001b[0m     symmetric_poly, interpolating_poly)\n\u001b[1;32m     12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolys\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrings\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m sring\n\u001b[1;32m     13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutilities\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m numbered_symbols, take, public\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/polys/specialpolys.py:298\u001b[0m\n\u001b[1;32m    294\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m dmp_mul(f, h, n, K), dmp_mul(g, h, n, K), h\n\u001b[1;32m    296\u001b[0m \u001b[38;5;66;03m# A few useful polynomials from Wang's paper ('78).\u001b[39;00m\n\u001b[0;32m--> 298\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolys\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrings\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ring\n\u001b[1;32m    300\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_f_0\u001b[39m():\n\u001b[1;32m    301\u001b[0m     R, x, y, z \u001b[38;5;241m=\u001b[39m ring(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx,y,z\u001b[39m\u001b[38;5;124m\"\u001b[39m, ZZ)\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/polys/rings.py:30\u001b[0m\n\u001b[1;32m     26\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolys\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolyoptions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (Domain \u001b[38;5;28;01mas\u001b[39;00m DomainOpt,\n\u001b[1;32m     27\u001b[0m                                      Order \u001b[38;5;28;01mas\u001b[39;00m OrderOpt, build_options)\n\u001b[1;32m     28\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolys\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpolyutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (expr_from_dict, _dict_reorder,\n\u001b[1;32m     29\u001b[0m                                    _parallel_dict_from_expr)\n\u001b[0;32m---> 30\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mprinting\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdefaults\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DefaultPrinting\n\u001b[1;32m     31\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutilities\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m public, subsets\n\u001b[1;32m     32\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutilities\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01miterables\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m is_sequence\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/printing/__init__.py:5\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;124;03m\"\"\"Printing subsystem\"\"\"\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpretty\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m pager_print, pretty, pretty_print, pprint, pprint_use_unicode, pprint_try_use_unicode\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlatex\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m latex, print_latex, multiline_latex\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmathml\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m mathml, print_mathml\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpython\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m python, print_python\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/printing/latex.py:18\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msympify\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m SympifyError\n\u001b[1;32m     17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlogic\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mboolalg\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m true, BooleanTrue, BooleanFalse\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtensor\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01marray\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NDimArray\n\u001b[1;32m     20\u001b[0m \u001b[38;5;66;03m# sympy.printing imports\u001b[39;00m\n\u001b[1;32m     21\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mprinting\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mprecedence\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m precedence_traditional\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/tensor/__init__.py:4\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;124;03m\"\"\"A module to manipulate symbolic objects with indices including tensors\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \n\u001b[1;32m      3\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mindexed\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m IndexedBase, Idx, Indexed\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mindex_methods\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_contraction_structure, get_indices\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m shape\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/tensor/indexed.py:114\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlogic\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m fuzzy_bool, fuzzy_not\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msympify\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _sympify\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mspecial\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtensor_functions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m KroneckerDelta\n\u001b[1;32m    115\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmultipledispatch\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m dispatch\n\u001b[1;32m    116\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutilities\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01miterables\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m is_sequence, NotIterable\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.9/site-packages/sympy/functions/__init__.py:21\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01melementary\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtrigonometric\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (sin, cos, tan,\n\u001b[1;32m     18\u001b[0m         sec, csc, cot, sinc, asin, acos, atan, asec, acsc, acot, atan2)\n\u001b[1;32m     19\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01melementary\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mexponential\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (exp_polar, exp, log,\n\u001b[1;32m     20\u001b[0m         LambertW)\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01melementary\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mhyperbolic\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (sinh, cosh, tanh, coth,\n\u001b[1;32m     22\u001b[0m         sech, csch, asinh, acosh, atanh, acoth, asech, acsch)\n\u001b[1;32m     23\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01melementary\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mintegers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m floor, ceiling, frac\n\u001b[1;32m     24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msympy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01melementary\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpiecewise\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (Piecewise, piecewise_fold,\n\u001b[1;32m     25\u001b[0m                                                   piecewise_exclusive)\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:1007\u001b[0m, in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:986\u001b[0m, in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:680\u001b[0m, in \u001b[0;36m_load_unlocked\u001b[0;34m(spec)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap_external>:846\u001b[0m, in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap_external>:978\u001b[0m, in \u001b[0;36mget_code\u001b[0;34m(self, fullname)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap_external>:647\u001b[0m, in \u001b[0;36m_compile_bytecode\u001b[0;34m(data, name, bytecode_path, source_path)\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    default_data_collator,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    Seq2SeqTrainer,\n",
    "    GenerationConfig,\n",
    ")\n",
    "from peft import get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType\n",
    "from datasets import load_dataset\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"t5-large\"\n",
    "tokenizer_name_or_path = \"t5-large\"\n",
    "\n",
    "checkpoint_name = \"financial_sentiment_analysis_prefix_tuning_v1.pt\"\n",
    "text_column = \"sentence\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 8\n",
    "lr = 1e0\n",
    "num_epochs = 5\n",
    "batch_size = 8"
   ]
  },
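  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1a2b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (added sketch, not part of the original run): device = \"cuda\" above\n",
    "# assumes a GPU is visible after setting CUDA_VISIBLE_DEVICES, so verify that before training.\n",
    "if torch.cuda.is_available():\n",
    "    print(\"Using GPU:\", torch.cuda.get_device_name(0))\n",
    "else:\n",
    "    print(\"No CUDA device visible; switch device to 'cpu' or fix CUDA_VISIBLE_DEVICES.\")"
   ]
  },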
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8d0850ac",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T09:50:04.808527Z",
     "start_time": "2023-05-30T09:49:56.953075Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 40960 || all params: 737709056 || trainable%: 0.005552324411210698\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "PeftModelForSeq2SeqLM(\n",
       "  (base_model): T5ForConditionalGeneration(\n",
       "    (shared): Embedding(32128, 1024)\n",
       "    (encoder): T5Stack(\n",
       "      (embed_tokens): Embedding(32128, 1024)\n",
       "      (block): ModuleList(\n",
       "        (0): T5Block(\n",
       "          (layer): ModuleList(\n",
       "            (0): T5LayerSelfAttention(\n",
       "              (SelfAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (relative_attention_bias): Embedding(32, 16)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (1): T5LayerFF(\n",
       "              (DenseReluDense): T5DenseActDense(\n",
       "                (wi): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "                (wo): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "                (act): ReLU()\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (1-23): 23 x T5Block(\n",
       "          (layer): ModuleList(\n",
       "            (0): T5LayerSelfAttention(\n",
       "              (SelfAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (1): T5LayerFF(\n",
       "              (DenseReluDense): T5DenseActDense(\n",
       "                (wi): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "                (wo): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "                (act): ReLU()\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (final_layer_norm): T5LayerNorm()\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (decoder): T5Stack(\n",
       "      (embed_tokens): Embedding(32128, 1024)\n",
       "      (block): ModuleList(\n",
       "        (0): T5Block(\n",
       "          (layer): ModuleList(\n",
       "            (0): T5LayerSelfAttention(\n",
       "              (SelfAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (relative_attention_bias): Embedding(32, 16)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (1): T5LayerCrossAttention(\n",
       "              (EncDecAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (2): T5LayerFF(\n",
       "              (DenseReluDense): T5DenseActDense(\n",
       "                (wi): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "                (wo): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "                (act): ReLU()\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (1-23): 23 x T5Block(\n",
       "          (layer): ModuleList(\n",
       "            (0): T5LayerSelfAttention(\n",
       "              (SelfAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (1): T5LayerCrossAttention(\n",
       "              (EncDecAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (2): T5LayerFF(\n",
       "              (DenseReluDense): T5DenseActDense(\n",
       "                (wi): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "                (wo): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "                (act): ReLU()\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (final_layer_norm): T5LayerNorm()\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (lm_head): Linear(in_features=1024, out_features=32128, bias=False)\n",
       "  )\n",
       "  (prompt_encoder): ModuleDict(\n",
       "    (default): PromptEmbedding(\n",
       "      (embedding): Embedding(40, 1024)\n",
       "    )\n",
       "  )\n",
       "  (word_embeddings): Embedding(32128, 1024)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# creating model\n",
    "peft_config = peft_config = PromptTuningConfig(\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    prompt_tuning_init=PromptTuningInit.TEXT,\n",
    "    num_virtual_tokens=20,\n",
    "    prompt_tuning_init_text=\"What is the sentiment of this article?\\n\",\n",
    "    inference_mode=False,\n",
    "    tokenizer_name_or_path=model_name_or_path,\n",
    ")\n",
    "\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()\n",
    "model"
   ]
  },
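  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c4e9a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (added sketch): prompt tuning only trains the virtual-token embeddings. For a\n",
    "# SEQ_2_SEQ_LM task, PEFT allocates prompts for both encoder and decoder, so the prompt table\n",
    "# has 2 * num_virtual_tokens rows of d_model each, which should match the \"trainable params\"\n",
    "# figure printed above (2 * 20 * 1024 = 40960 for t5-large).\n",
    "expected = 2 * peft_config.num_virtual_tokens * model.base_model.config.d_model\n",
    "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(f\"expected {expected} trainable params, counted {trainable}\")"
   ]
  },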
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4ee2babf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T09:50:09.224782Z",
     "start_time": "2023-05-30T09:50:08.172611Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset financial_phrasebank (/data/proxem/huggingface/datasets/financial_phrasebank/sentences_allagree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3a799c64a2c43258dc6166c90e2e49f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/2037 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/227 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'sentence': 'The price of the 10,000 kroon par value bonds was 9663,51 kroons in the primary issue .',\n",
       " 'label': 1,\n",
       " 'text_label': 'neutral'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# loading dataset\n",
    "dataset = load_dataset(\"financial_phrasebank\", \"sentences_allagree\")\n",
    "dataset = dataset[\"train\"].train_test_split(test_size=0.1)\n",
    "dataset[\"validation\"] = dataset[\"test\"]\n",
    "del dataset[\"test\"]\n",
    "\n",
    "classes = dataset[\"train\"].features[\"label\"].names\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "\n",
    "dataset[\"train\"][0]"
   ]
  },
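  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b8d2f6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional (added sketch): look at the class balance of the freshly created train split. The\n",
    "# \"sentences_allagree\" subset is skewed toward the neutral class, which is useful context for\n",
    "# the accuracy numbers reported during training below.\n",
    "from collections import Counter\n",
    "\n",
    "print(Counter(dataset[\"train\"][\"text_label\"]))"
   ]
  },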
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "adf9608c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T09:50:12.176663Z",
     "start_time": "2023-05-30T09:50:11.421273Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
      "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
      "- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.\n",
      "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
      "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/2037 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/227 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    inputs = examples[text_column]\n",
    "    targets = examples[label_column]\n",
    "    model_inputs = tokenizer(inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "    labels = tokenizer(targets, max_length=2, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "    labels = labels[\"input_ids\"]\n",
    "    labels[labels == tokenizer.pad_token_id] = -100\n",
    "    model_inputs[\"labels\"] = labels\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"].shuffle()\n",
    "eval_dataset = processed_datasets[\"validation\"]"
   ]
  },
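  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e7c1a3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional (added sketch): decode one processed example. With max_length = 8 the inputs are\n",
    "# truncated quite aggressively, and pad positions in the labels were set to -100 above so the\n",
    "# loss ignores them; this cell just makes both effects visible.\n",
    "sample = train_dataset[0]\n",
    "print(tokenizer.decode(sample[\"input_ids\"], skip_special_tokens=True))\n",
    "print([tok for tok in sample[\"labels\"] if tok != -100])"
   ]
  },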
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6b3a4090",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T09:53:10.336984Z",
     "start_time": "2023-05-30T09:50:14.780995Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/transformers/optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1275' max='1275' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1275/1275 02:52, Epoch 5/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>4.784800</td>\n",
       "      <td>0.576933</td>\n",
       "      <td>0.559471</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.648200</td>\n",
       "      <td>0.437575</td>\n",
       "      <td>0.577093</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.536200</td>\n",
       "      <td>0.397857</td>\n",
       "      <td>0.625551</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.472200</td>\n",
       "      <td>0.373160</td>\n",
       "      <td>0.643172</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.452500</td>\n",
       "      <td>0.370234</td>\n",
       "      <td>0.656388</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1275, training_loss=1.3787811279296875, metrics={'train_runtime': 173.3699, 'train_samples_per_second': 58.747, 'train_steps_per_second': 7.354, 'total_flos': 344546979840000.0, 'train_loss': 1.3787811279296875, 'epoch': 5.0})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# training and evaluation\n",
    "\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    preds, labels = eval_preds\n",
    "    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
    "    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for pred, true in zip(preds, labels):\n",
    "        if pred.strip() == true.strip():\n",
    "            correct += 1\n",
    "        total += 1\n",
    "    accuracy = correct / total\n",
    "    return {\"accuracy\": accuracy}\n",
    "\n",
    "\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    \"out\",\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    learning_rate=lr,\n",
    "    num_train_epochs=num_epochs,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    save_strategy=\"no\",\n",
    "    report_to=[],\n",
    "    predict_with_generate=True,\n",
    "    generation_config=GenerationConfig(max_length=max_length),\n",
    ")\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    data_collator=default_data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "trainer.train()"
   ]
  },
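  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b4d6f8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: an explicit final evaluation pass (sketch). With predict_with_generate=True,\n",
    "# the returned dict holds eval_loss plus the eval_accuracy from compute_metrics.\n",
    "metrics = trainer.evaluate()\n",
    "print(metrics)"
   ]
  },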
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a8de6005",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T09:53:13.045146Z",
     "start_time": "2023-05-30T09:53:13.035612Z"
    }
   },
   "outputs": [],
   "source": [
    "# saving model\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "model.save_pretrained(peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bd20cd4c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T09:53:15.240763Z",
     "start_time": "2023-05-30T09:53:15.059304Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "164K\tt5-large_PROMPT_TUNING_SEQ_2_SEQ_LM/adapter_model.bin\r\n"
     ]
    }
   ],
   "source": [
    "ckpt = f\"{peft_model_id}/adapter_model.bin\"\n",
    "!du -h $ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "76c2fc29",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T09:53:25.055105Z",
     "start_time": "2023-05-30T09:53:17.797989Z"
    }
   },
   "outputs": [],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d997f1cc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T09:53:26.777030Z",
     "start_time": "2023-05-30T09:53:26.013697Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Aspocomp Group , headquartered in Helsinki , Finland , develops interconnection solutions for the electronics industry .\n",
      "{'input_ids': tensor([[   71,  7990,  7699,  1531,     3,     6,     3, 27630,    16, 29763,\n",
      "             3,     6, 16458,     3,     6,  1344,     7,  1413, 28102,  1275,\n",
      "            21,     8, 12800,   681,     3,     5,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1]])}\n",
      "tensor([[   0, 7163,    1]])\n",
      "['neutral']\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "i = 107\n",
    "inputs = tokenizer(dataset[\"validation\"][text_column][i], return_tensors=\"pt\")\n",
    "print(dataset[\"validation\"][text_column][i])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
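  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e5a7c9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: generate predictions for a few validation examples and\n",
    "# compare them with the reference labels (reuses text_column / label_column from above).\n",
    "texts = dataset[\"validation\"][text_column][:8]\n",
    "refs = dataset[\"validation\"][label_column][:8]\n",
    "batch = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors=\"pt\")\n",
    "\n",
    "with torch.no_grad():\n",
    "    gen = model.generate(input_ids=batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"], max_new_tokens=10)\n",
    "preds = tokenizer.batch_decode(gen, skip_special_tokens=True)\n",
    "\n",
    "for text, pred, ref in zip(texts, preds, refs):\n",
    "    print(f\"{pred:>10} | {ref:>10} | {text[:60]}\")"
   ]
  },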
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb746c1e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "peft",
   "language": "python",
   "name": "peft"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}