{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "62c5865f",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/test_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c7800a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colab-only setup: clone the repository and install its requirements.\n",
    "# Outside Colab, `import google.colab` raises ImportError and we assume the\n",
    "# repository is already checked out with its dependencies installed.\n",
    "# Only the import is guarded, so real setup failures are not silently hidden.\n",
    "try:\n",
    "    import google.colab  # noqa: F401 (imported only to detect Colab)\n",
    "except ImportError:\n",
    "    pass\n",
    "else:\n",
    "    !git clone -q https://github.com/teticio/audio-diffusion.git\n",
    "    %cd audio-diffusion\n",
    "    %pip install -q -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b447e2c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the parent of the current working directory at the front of the import\n",
    "# path. When the notebook is run from the repository's notebooks/ directory,\n",
    "# that parent is the repository root, so `audiodiffusion` is imported from the\n",
    "# local checkout.\n",
    "sys.path.insert(\n",
    "    0,\n",
    "    os.path.dirname(os.path.abspath(\"\")),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2fc0e7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "from IPython.display import Audio\n",
    "from audiodiffusion.mel import Mel\n",
    "from audiodiffusion import AudioDiffusion"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fd945bb",
   "metadata": {},
   "source": [
    "### Select model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97f24046",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@markdown teticio/audio-diffusion-256                     - trained on my Spotify \"liked\" playlist\n",
    "\n",
    "#@markdown teticio/audio-diffusion-breaks-256              - trained on samples used in music\n",
    "\n",
    "#@markdown teticio/audio-diffusion-instrumental-hiphop-256 - trained on instrumental hiphop\n",
    "\n",
    "# Hub id of the pretrained model (also used below to load its training set).\n",
    "# Every dropdown option must be a full \"owner/name\" id matching the list above.\n",
    "model_id = \"teticio/audio-diffusion-256\"  #@param [\"teticio/audio-diffusion-256\", \"teticio/audio-diffusion-breaks-256\", \"teticio/audio-diffusion-instrumental-hiphop-256\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d45c36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrained pipeline for the model selected above.\n",
    "audio_diffusion = AudioDiffusion(model_id=model_id)\n",
    "# 256x256 mel-spectrogram converter, used later to load and slice audio files\n",
    "# and to turn spectrogram images back into audio.\n",
    "mel = Mel(x_res=256, y_res=256)\n",
    "# RNG passed to generation; seeds are printed so results can be reproduced.\n",
    "generator = torch.Generator()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "011fb5a1",
   "metadata": {},
   "source": [
    "### Run model inference to generate mel spectrograms, audio and loops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b809fed5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw ten fresh random seeds, printing each one so that any result can be\n",
    "# regenerated exactly (paste its seed into the next section).\n",
    "for _ in range(10):\n",
    "    seed = generator.seed()\n",
    "    print(f'Seed = {seed}')\n",
    "    generator.manual_seed(seed)\n",
    "    image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio(generator=generator)\n",
    "    display(image)\n",
    "    display(Audio(audio, rate=sample_rate))\n",
    "    # loop_it may return None when it cannot find suitable loop points.\n",
    "    loop = AudioDiffusion.loop_it(audio, sample_rate)\n",
    "    if loop is None:\n",
    "        print(\"Unable to determine loop points\")\n",
    "    else:\n",
    "        display(Audio(loop, rate=sample_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bb03e33",
   "metadata": {},
   "source": [
    "### Generate variations of audio"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80e5b5fa",
   "metadata": {},
   "source": [
    "Try adjusting `start_steps`: values closer to 0 produce entirely new samples, while values closer to 1,000 produce samples that are more faithful to the original."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e637e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regenerate a specific sample from a seed printed above. The variation and\n",
    "# continuation sections below start from this `audio`.\n",
    "seed = 16183389798189209330  #@param {type:\"integer\"}\n",
    "generator.manual_seed(seed)\n",
    "image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio(generator=generator)\n",
    "display(image, Audio(audio, rate=sample_rate))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0fefe28",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "start_steps = 500  #@param {type:\"slider\", min:0, max:1000, step:10}\n",
    "\n",
    "def one_loop(clip, sr):\n",
    "    \"\"\"Return one loop of `clip`, or the whole clip if no loop points are found.\n",
    "\n",
    "    `AudioDiffusion.loop_it` returns None when it cannot determine loop points,\n",
    "    which would otherwise make `np.concatenate` below fail.\n",
    "    \"\"\"\n",
    "    loop = AudioDiffusion.loop_it(clip, sr, loops=1)\n",
    "    return clip if loop is None else loop\n",
    "\n",
    "# Build a track from one loop of the original followed by one loop of each\n",
    "# variation; the pieces are concatenated once at the end.\n",
    "segments = [one_loop(audio, sample_rate)]\n",
    "for variation in range(12):\n",
    "    image2, (sample_rate, audio2) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
    "        raw_audio=audio,\n",
    "        start_step=start_steps)\n",
    "    display(image2)\n",
    "    display(Audio(audio2, rate=sample_rate))\n",
    "    segments.append(one_loop(audio2, sample_rate))\n",
    "track = np.concatenate(segments)\n",
    "display(Audio(track, rate=sample_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58a876c1",
   "metadata": {},
   "source": [
    "### Generate continuations (\"out-painting\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b95d5780",
   "metadata": {},
   "outputs": [],
   "source": [
    "overlap_secs = 2  #@param {type:\"integer\"}\n",
    "start_step = 0  #@param {type:\"slider\", min:0, max:1000, step:10}\n",
    "overlap_samples = overlap_secs * sample_rate\n",
    "# Each step feeds the last `overlap_secs` of the previous segment back in\n",
    "# (with mask_start_secs covering it) and keeps only the newly generated part.\n",
    "# `segment` is local so that `audio` from the seed cell is not overwritten and\n",
    "# the cells above can be re-run with the same input.\n",
    "segment = audio\n",
    "pieces = [audio]\n",
    "for variation in range(12):\n",
    "    image2, (sample_rate, audio2) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
    "        raw_audio=segment[-overlap_samples:],\n",
    "        start_step=start_step,\n",
    "        mask_start_secs=overlap_secs)\n",
    "    display(image2)\n",
    "    display(Audio(audio2, rate=sample_rate))\n",
    "    pieces.append(audio2[overlap_samples:])\n",
    "    segment = audio2\n",
    "track = np.concatenate(pieces)\n",
    "display(Audio(track, rate=sample_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6434d3f",
   "metadata": {},
   "source": [
    "### Remix (style transfer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0da030b2",
   "metadata": {},
   "source": [
    "Alternatively, you can start from a different audio file altogether, resulting in a kind of style transfer. Keeping the same seed for every segment fixes the style, while masking helps stitch consecutive segments together more smoothly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc620a80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# On Colab, upload the audio file to remix. Elsewhere, fall back to a local\n",
    "# path (edit it to point at an audio file on your machine).\n",
    "# Only the Colab import is guarded, so a failed or empty upload is reported\n",
    "# instead of silently switching to a path that does not exist on Colab.\n",
    "try:\n",
    "    from google.colab import files\n",
    "except ImportError:\n",
    "    audio_file = \"/home/teticio/Music/liked/El Michels Affair - Glaciers Of Ice.mp3\"\n",
    "else:\n",
    "    uploaded = files.upload()\n",
    "    if not uploaded:\n",
    "        raise ValueError(\"No file was uploaded\")\n",
    "    audio_file = next(iter(uploaded))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a257e69",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "start_step = 500  #@param {type:\"slider\", min:0, max:1000, step:10}\n",
    "overlap_secs = 2  #@param {type:\"integer\"}\n",
    "mel.load_audio(audio_file)\n",
    "overlap_samples = overlap_secs * mel.get_sample_rate()\n",
    "slice_size = mel.x_res * mel.hop_length\n",
    "stride = slice_size - overlap_samples\n",
    "assert stride > 0, \"overlap_secs must be shorter than one slice\"\n",
    "\n",
    "\n",
    "def match_peak(generated, reference):\n",
    "    \"\"\"Scale `generated` so its peak absolute amplitude matches `reference`'s.\n",
    "\n",
    "    A silent `generated` is returned unchanged instead of dividing by zero,\n",
    "    which would otherwise inject NaNs into the model input.\n",
    "    \"\"\"\n",
    "    generated_peak = np.max(np.abs(generated))\n",
    "    if generated_peak == 0:\n",
    "        return generated\n",
    "    return generated * (np.max(np.abs(reference)) / generated_peak)\n",
    "\n",
    "\n",
    "# Reuse one seed for every slice so the style stays the same across the track.\n",
    "generator = torch.Generator()\n",
    "seed = generator.seed()\n",
    "print(f'Seed = {seed}')\n",
    "track = np.array([])\n",
    "progress = None  # display handle, updated in place as the track grows\n",
    "for sample in range(len(mel.audio) // stride):\n",
    "    is_first = sample == 0\n",
    "    generator.manual_seed(seed)\n",
    "    audio = np.array(mel.audio[sample * stride:sample * stride + slice_size])\n",
    "    if not is_first:\n",
    "        # Overwrite the start of this slice with the end of the previous remixed\n",
    "        # slice, rescaled to the loudness of the original, so the two join up.\n",
    "        audio[:overlap_samples] = match_peak(audio2[-overlap_samples:],\n",
    "                                             audio[:overlap_samples])\n",
    "    _, (sample_rate, audio2) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
    "        raw_audio=audio,\n",
    "        start_step=start_step,\n",
    "        generator=generator,\n",
    "        mask_start_secs=0 if is_first else overlap_secs)\n",
    "    # The overlap was already appended as the tail of the previous slice.\n",
    "    skip = 0 if is_first else overlap_samples\n",
    "    track = np.concatenate([track, audio2[skip:]])\n",
    "    # Update a single player instead of embedding a new copy of the growing\n",
    "    # track on every iteration (which bloats the notebook quadratically).\n",
    "    player = Audio(track, rate=sample_rate)\n",
    "    if progress is None:\n",
    "        progress = display(player, display_id=True)\n",
    "    else:\n",
    "        progress.update(player)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "924ff9d5",
   "metadata": {},
   "source": [
    "### Fill the gap (\"in-painting\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0200264c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uses the audio file loaded into `mel` by the remix section above.\n",
    "# Named `slice_index` so it does not shadow the builtin `slice`.\n",
    "slice_index = 3  #@param {type:\"integer\"}\n",
    "audio = mel.get_audio_slice(slice_index)\n",
    "# Mask the first and last second of the slice so that only the gap between\n",
    "# them is generated.\n",
    "_, (sample_rate, audio2) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
    "    raw_audio=audio,\n",
    "    mask_start_secs=1,\n",
    "    mask_end_secs=1)\n",
    "display(Audio(audio, rate=sample_rate))\n",
    "display(Audio(audio2, rate=sample_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef54cef3",
   "metadata": {},
   "source": [
    "### Compare results with random sample from training set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "269ee816",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumes the training dataset is published on the Hub under the same id as\n",
    "# the model selected above.\n",
    "ds = load_dataset(path=model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9023846",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show a randomly chosen spectrogram from the training split for comparison.\n",
    "train_example = random.choice(ds['train'])\n",
    "image = train_example['image']\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "492e2334",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the training spectrogram back to audio with the notebook's 256x256\n",
    "# Mel converter.\n",
    "audio = mel.image_to_audio(image)\n",
    "Audio(audio, rate=mel.get_sample_rate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4deb47f4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "huggingface",
   "language": "python",
   "name": "huggingface"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}