{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Union\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModel.from_pretrained(\"InstaDeepAI/segment_borzoi\", trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define useful functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_sequences(sequences: Union[str, List[str]]) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    One-hot encode a DNA sequence or a batch of DNA sequences.\n",
    "\n",
    "    Args:\n",
    "        sequences (Union[str, List[str]]): Either a DNA sequence or a list of DNA sequences\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: One-hot encoded\n",
    "            - If `sequences` is just one sequence (str), output shape is (seq_len, 4), seq_len being the length of a sequence\n",
    "            - If `sequences` is a list of sequences, output shape is (num_sequences, seq_len, 4)\n",
    "    \n",
    "    Example:\n",
    "        >>> sequences = [\"AC\", \"GT\"]\n",
    "        >>> encode_sequences(sequences)\n",
    "        tensor([[[1., 0., 0., 0.],\n",
    "                 [0., 1., 0., 0.]],\n",
    "\n",
    "                [[0., 0., 1., 0.],\n",
    "                 [0., 0., 0., 1.]]])\n",
    "    \"\"\"\n",
    "    one_hot_map = {\n",
    "        'a': torch.tensor([1., 0., 0., 0.]),\n",
    "        'c': torch.tensor([0., 1., 0., 0.]),\n",
    "        'g': torch.tensor([0., 0., 1., 0.]),\n",
    "        't': torch.tensor([0., 0., 0., 1.]),\n",
    "        'n': torch.tensor([0., 0., 0., 0.]),\n",
    "        'A': torch.tensor([1., 0., 0., 0.]),\n",
    "        'C': torch.tensor([0., 1., 0., 0.]),\n",
    "        'G': torch.tensor([0., 0., 1., 0.]),\n",
    "        'T': torch.tensor([0., 0., 0., 1.]),\n",
    "        'N': torch.tensor([0., 0., 0., 0.])\n",
    "    }\n",
    "\n",
    "    def encode_sequence(seq_str):\n",
    "        one_hot_list = []\n",
    "        for char in seq_str:\n",
    "            one_hot_vector = one_hot_map.get(char, torch.tensor([0.25, 0.25, 0.25, 0.25]))\n",
    "            one_hot_list.append(one_hot_vector)\n",
    "        return torch.stack(one_hot_list)\n",
    "\n",
    "    if isinstance(sequences, list):\n",
    "        return torch.stack([encode_sequence(seq) for seq in sequences])\n",
    "    else:\n",
    "        return encode_sequence(sequences)"
   ]
  },
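  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quick sanity check of the encoder on the toy example from the docstring: two 2-bp sequences should produce a tensor of shape (2, 2, 4)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected output matches the docstring example above\n",
    "encode_sequences([\"AC\", \"GT\"])"
   ]
  },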
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sequences = [\"A\"*524_288, \"G\"*524_288]\n",
    "one_hot_encoding = encode_sequences(sequences)"
   ]
  },
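  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check the shape of the encoded batch: two sequences of 524,288 bp each should give a tensor of shape (2, 524288, 4)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected: torch.Size([2, 524288, 4])\n",
    "print(one_hot_encoding.shape)"
   ]
  },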
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = model(one_hot_encoding)"
   ]
  },
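  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model returns a dictionary with a `logits` tensor. Its exact dimensions depend on the model configuration, so we inspect the shape before printing the values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(preds['logits'].shape)"
   ]
  },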
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(preds['logits'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "genomics-research-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}