{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "62c5865f",
   "metadata": {
    "id": "62c5865f"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/train_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c7800a6",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6c7800a6",
    "outputId": "ed18f4a9-ccea-4d7c-c82b-1749f1041f6c"
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    # Importing google.colab only succeeds on Google Colab, so it doubles as an environment probe.\n",
    "    import google.colab\n",
    "except ImportError:\n",
    "    # Not on Colab: skip the clone / install steps below.\n",
    "    pass\n",
    "else:\n",
    "    # On Colab: fetch the repository, switch into it and install it together with its requirements.\n",
    "    # Only ImportError is caught above, so any other failure is no longer silently discarded.\n",
    "    !git clone -q https://github.com/teticio/audio-diffusion.git\n",
    "    %cd audio-diffusion\n",
    "    %pip install -q -r requirements.txt ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2fc0e7a",
   "metadata": {
    "id": "c2fc0e7a"
   },
   "outputs": [],
   "source": [
    "from IPython.display import Audio\n",
    "from audiodiffusion import AudioDiffusion"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "MqlpL75_mDVv",
   "metadata": {
    "id": "MqlpL75_mDVv"
   },
   "source": [
    "### Upload / specify audio files to train on\n",
    "Provide some MP3 or WAV files that will be split into samples and converted to Mel spectrograms. For a resolution of 256, the samples will be about 5 seconds long."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "jg1zAHVsmCBG",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 73
    },
    "id": "jg1zAHVsmCBG",
    "outputId": "414244c9-02b6-4ccf-cbfd-83f9022a0fc1"
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    # Importing google.colab.files only succeeds on Google Colab.\n",
    "    from google.colab import files\n",
    "except ImportError:\n",
    "    # Running locally: read audio from this directory (edit the path to suit your machine).\n",
    "    input_dir = \"/home/teticio/Music/liked\"\n",
    "else:\n",
    "    # On Colab: open the upload dialog and use the current directory as the input directory.\n",
    "    # Kept outside the try block so that an upload failure surfaces instead of silently\n",
    "    # falling back to the local path above.\n",
    "    input_dir = '.'\n",
    "    files.upload()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10v0RCSUu75P",
   "metadata": {
    "id": "10v0RCSUu75P"
   },
   "source": [
    "### Prepare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "NJNeEU6ftaTM",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NJNeEU6ftaTM",
    "outputId": "6c5bed15-c821-4def-eb90-3ab1a17b3c3d"
   },
   "outputs": [],
   "source": [
    "# Convert the audio files in input_dir to 256x256 spectrogram images stored under data/.\n",
    "# The interpolated path is quoted so that a directory name containing spaces is passed as one argument.\n",
    "!python scripts/audio_to_images.py \\\n",
    "  --resolution 256,256 \\\n",
    "  --input_dir \"{input_dir}\" \\\n",
    "  --output_dir data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5mGeXyJFvQCO",
   "metadata": {
    "id": "5mGeXyJFvQCO"
   },
   "source": [
    "### Train model\n",
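    "The model is trained with the DDIM scheduler (`--scheduler ddim`), which generates samples much faster than DDPM because it needs far fewer denoising steps."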
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "JGnlePbLvTOH",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JGnlePbLvTOH",
    "outputId": "69b6f53e-25a3-4c59-e205-2eab42889cd8"
   },
   "outputs": [],
   "source": [
    "# Run scripts/train_unet.py on the dataset written to data/ by the previous step; output goes to model/.\n",
    "# NOTE(review): --save_images_epochs (100) exceeds --num_epochs (10), so presumably no sample images\n",
    "# are saved during training - confirm against scripts/train_unet.py.\n",
    "!python scripts/train_unet.py \\\n",
    "  --dataset_name data \\\n",
    "  --output_dir model \\\n",
    "  --num_epochs 10 \\\n",
    "  --train_batch_size 2 \\\n",
    "  --eval_batch_size 2 \\\n",
    "  --gradient_accumulation_steps 8 \\\n",
    "  --save_images_epochs 100 \\\n",
    "  --save_model_epochs 1 \\\n",
    "  --scheduler ddim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "nTMAYEtMxtt0",
   "metadata": {
    "id": "nTMAYEtMxtt0"
   },
   "source": [
    "### Generate samples with model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b294a94a",
   "metadata": {
    "id": "b294a94a"
   },
   "outputs": [],
   "source": [
    "# 'model' is the --output_dir given to scripts/train_unet.py; presumably this loads the\n",
    "# trained model from that directory (confirm against AudioDiffusion).\n",
    "audio_diffusion = AudioDiffusion('model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "k2bKq3aqyAIM",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 363,
     "referenced_widgets": [
      "474d4db933d54e0497da4076a7fe135b",
      "a849a3a1b46947db830a6a087411ec68",
      "378f819239274ac88d913714bc27bf06",
      "cc3b33e508744206955b26a417fbbdec",
      "6015e5a9e6774e9abf7db273bca57363",
      "629c21c68d22447185bb961e22bce4a6",
      "2d5abefbc2ed4b72aed8c4f8ddc7a00c",
      "11d1dbae00764a1c9dcc899c0b0f67dc",
      "acdb5ddc7bda411a948689787b18b21e",
      "9c4955f9d0f443a7b28ed827c5cdb37f",
      "f9a1a976d82148f8961e80c357bc2764"
     ]
    },
    "id": "k2bKq3aqyAIM",
    "outputId": "d48238fe-ae36-4736-e67b-b69e3729304a"
   },
   "outputs": [],
   "source": [
    "# Sample one spectrogram image with its audio and sample rate, then render the image and an audio player.\n",
    "image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio()\n",
    "display(image, Audio(audio, rate=sample_rate))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "K2qAIJzg2DNK",
   "metadata": {
    "id": "K2qAIJzg2DNK"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "huggingface",
   "language": "python",
   "name": "huggingface"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "11d1dbae00764a1c9dcc899c0b0f67dc": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2d5abefbc2ed4b72aed8c4f8ddc7a00c": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "378f819239274ac88d913714bc27bf06": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_11d1dbae00764a1c9dcc899c0b0f67dc",
      "max": 50,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_acdb5ddc7bda411a948689787b18b21e",
      "value": 50
     }
    },
    "474d4db933d54e0497da4076a7fe135b": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_a849a3a1b46947db830a6a087411ec68",
       "IPY_MODEL_378f819239274ac88d913714bc27bf06",
       "IPY_MODEL_cc3b33e508744206955b26a417fbbdec"
      ],
      "layout": "IPY_MODEL_6015e5a9e6774e9abf7db273bca57363"
     }
    },
    "6015e5a9e6774e9abf7db273bca57363": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "629c21c68d22447185bb961e22bce4a6": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9c4955f9d0f443a7b28ed827c5cdb37f": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a849a3a1b46947db830a6a087411ec68": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_629c21c68d22447185bb961e22bce4a6",
      "placeholder": "​",
      "style": "IPY_MODEL_2d5abefbc2ed4b72aed8c4f8ddc7a00c",
      "value": "100%"
     }
    },
    "acdb5ddc7bda411a948689787b18b21e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "cc3b33e508744206955b26a417fbbdec": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_9c4955f9d0f443a7b28ed827c5cdb37f",
      "placeholder": "​",
      "style": "IPY_MODEL_f9a1a976d82148f8961e80c357bc2764",
      "value": " 50/50 [00:07&lt;00:00,  8.13it/s]"
     }
    },
    "f9a1a976d82148f8961e80c357bc2764": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}