{ "cells": [
{ "metadata": {}, "cell_type": "markdown", "source": [
  "# Age Classification Model\n",
  "\n",
  "Fine-tunes the Ultralytics `yolo11n-cls.pt` classification checkpoint on the Hugging Face dataset `prithivMLmods/Age-Classification-Set` to predict one of five age groups: `0-12`, `13-20`, `21-44`, `45-64` and `65+`."
 ], "id": "49bedebca98dc3f9" },
{ "metadata": {}, "cell_type": "markdown", "source": "## 1. Investigating the dataset", "id": "6e7f00ae6117ea4" },
{ "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:19:03.326890Z", "start_time": "2025-08-31T10:18:59.681354Z" } }, "cell_type": "code", "source": [
  "from datasets import load_dataset\n",
  "import matplotlib.pyplot as plt\n",
  "import random"
 ], "id": "7bc7f93118a4c9d3", "outputs": [ { "name": "stderr", "output_type": "stream", "text": [
  "D:\\Documents\\Personal Projects\\Age_Predictor\\.venv\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
  "  from .autonotebook import tqdm as notebook_tqdm\n"
 ] } ], "execution_count": 1 },
{ "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:19:05.980916Z", "start_time": "2025-08-31T10:19:03.339417Z" } }, "cell_type": "code", "source": [
  "# Keep the full split dictionary under its own name so the cells below can be re-run safely.\n",
  "dataset_dict = load_dataset(\"prithivMLmods/Age-Classification-Set\")"
 ], "id": "fa7acd0711ab46f4", "outputs": [], "execution_count": 2 },
{ "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:19:05.995917Z", "start_time": "2025-08-31T10:19:05.986918Z" } }, "cell_type": "code", "source": [
  "# Class index -> age-group name, e.g. 0 -> '0-12'.\n",
  "labels = dataset_dict[\"train\"].features[\"label\"].names\n",
  "label_mapping = dict(enumerate(labels))\n",
  "label_mapping"
 ], "id": "4798be5f4874f592", "outputs": [ { "data": { "text/plain": [ "{0: '0-12', 1: '13-20', 2: '21-44', 3: '45-64', 4: '65+'}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 3 },
{ "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:20:45.518375Z", "start_time": "2025-08-31T10:20:45.505373Z" } }, "cell_type": "code", "source": [
  "# Everything below works on the `train` split, which is re-split into train/val/test later.\n",
  "ds = dataset_dict[\"train\"]\n",
  "print(ds)\n",
  "print(len(ds))"
 ], "id": "ce4dc513c7e221c0", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "Dataset({\n",
  "    features: ['image', 'label'],\n",
  "    num_rows: 19016\n",
  "})\n",
  "19016\n"
 ] } ], "execution_count": 4 },
{ "metadata": {}, "cell_type": "code", "source": "ds[0][\"label\"], ds[0][\"image\"]", "id": "163aabcd20d50a5f", "outputs": [], "execution_count": null },
{ "metadata": {}, "cell_type": "code", "source": [
  "def print_samples(nrows=2, ncols=3):\n",
  "    \"\"\"Plot a grid of randomly chosen images from `ds`, each titled with its age group.\"\"\"\n",
  "    # squeeze=False keeps `axes` 2-D even for a single row/column, so flatten() always works.\n",
  "    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 7), squeeze=False)\n",
  "    # Draw all (distinct) indices at once instead of re-sampling inside the loop.\n",
  "    sample_indices = random.sample(range(len(ds)), k=nrows * ncols)\n",
  "    for ax, ind in zip(axes.flatten(), sample_indices):\n",
  "        example = ds[ind]  # fetch the row once instead of once per column\n",
  "        ax.imshow(example['image'])\n",
  "        ax.set_title(label_mapping[example['label']])\n",
  "        ax.axis('off')\n",
  "    plt.tight_layout()  # Adjust the layout to prevent titles and labels from overlapping\n",
  "    plt.show()\n",
  "\n",
  "print_samples()"
 ], "id": "4d730d25dc440787", "outputs": [], "execution_count": null },
{ "metadata": {}, "cell_type": "markdown", "source": [
  "### Structure the dataset folder for YOLO\n",
  "\n",
  "The images are exported into an image-folder layout, `dataset/age/{train,val,test}/AGE_GROUP/IMAGE.png`, and that root folder is later passed to Ultralytics as `data=`."
 ], "id": "29eecd7967fc11ab" },
{ "metadata": {}, "cell_type": "markdown", "source": "Split the data into `train`, `val` and `test` subsets (80% / 10% / 10% within each age group).", "id": "c29d47b3a35dcb5a" },
{ "metadata": {}, "cell_type": "markdown", "source": "First, find the dataset indices that belong to each age group.", "id": "47f70337e70929b9" },
{ "metadata": {}, "cell_type": "code", "source": [
  "indices_by_class = {}\n",
  "\n",
  "# Read only the `label` column: iterating `ds` row by row also loads every image,\n",
  "# which took ~30 s for the 19016 rows.\n",
  "for ind, label_id in enumerate(ds[\"label\"]):\n",
  "    indices_by_class.setdefault(label_mapping[label_id], []).append(ind)"
 ], "id": "4dd7a79968fd8d6", "outputs": [], "execution_count": null },
{ "metadata": { "ExecuteTime": {
"end_time": "2025-08-31T10:22:53.376930Z", "start_time": "2025-08-31T10:22:53.368930Z" } }, "cell_type": "code", "source": [ "for cls, indices in indices_by_class.items():\n", " print(f\"{cls}: {len(indices)} samples\")" ], "id": "26e3e3c24295096e", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0-12: 2193 samples\n", "13-20: 1779 samples\n", "21-44: 9999 samples\n", "45-64: 3785 samples\n", "65+: 1260 samples\n" ] } ], "execution_count": 6 }, { "metadata": {}, "cell_type": "markdown", "source": "Because the number of `21-44` group is exceeded, we randomly reduce the number of these images to 4000 samples (comparable to the second most class in the dataset).", "id": "b97969f053af0c4b" }, { "metadata": {}, "cell_type": "code", "source": [ "import random\n", "random.seed(42)\n", "\n", "num_sample_remain = 4000\n", "\n", "indices_by_class[\"21-44\"] = random.sample(indices_by_class[\"21-44\"], k=num_sample_remain)" ], "id": "c68dd70cd81d1ff2", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "len(indices_by_class[\"21-44\"])", "id": "e769d2571e888bc", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "TRAIN_RATIO = 0.8\n", "VALIDATION_RATIO = 0.1\n", "\n", "ds_indices = {\n", " 'train': [],\n", " 'val': [],\n", " 'test': []\n", "}\n", "\n", "for age, indices in indices_by_class.items():\n", " print(f\"Splitting dataset for {age} group...\")\n", "\n", " num_train_samples = int(TRAIN_RATIO * len(indices))\n", " num_validation_samples = int(VALIDATION_RATIO * len(indices))\n", "\n", " random.shuffle(indices)\n", " train_indices = indices[:num_train_samples]\n", " validation_indices = indices[num_train_samples:num_train_samples + num_validation_samples]\n", " test_indices = indices[num_train_samples + num_validation_samples:]\n", "\n", " ds_indices[\"train\"] += train_indices\n", " ds_indices[\"val\"] += validation_indices\n", " ds_indices[\"test\"] += 
test_indices\n", "\n", "random.shuffle(ds_indices[\"train\"])\n", "random.shuffle(ds_indices[\"val\"])\n", "random.shuffle(ds_indices[\"test\"])" ], "id": "cd898f55f8fe2b00", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "import os\n", "from tqdm import tqdm\n", "ROOT = \"D:\\Documents\\Personal Projects\\Age_Predictor\"\n", "DATA_ROOT = os.path.join(ROOT, \"dataset\", \"age\")\n", "os.makedirs(DATA_ROOT, exist_ok=True)\n", "\n", "for split in ['train', 'val', 'test']:\n", " split_dir = os.path.join(DATA_ROOT, split)\n", " os.makedirs(split_dir, exist_ok=True)\n", "\n", " for idx in tqdm(ds_indices[split], total=len(ds_indices[split]), desc=f\"Processing {split} split...\"):\n", " example = ds[idx]\n", " pil_image = example['image']\n", " label = label_mapping[example['label']]\n", "\n", " # Create a directory for this class if it doesn't exist\n", " class_dir = os.path.join(split_dir, label)\n", " os.makedirs(class_dir, exist_ok=True)\n", "\n", " # Save this image to the class directory\n", " image_filename = f\"{idx}_{label}.png\"\n", " image_path = os.path.join(class_dir, image_filename)\n", " pil_image.save(image_path)" ], "id": "ab395962c474a497", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": "# 2. 
Setup Model and Training Configurations", "id": "94822addb71826fb" }, { "metadata": {}, "cell_type": "code", "source": [ "from ultralytics import YOLO\n", "\n", "# Load a model\n", "model = YOLO(\"yolo11n-cls.pt\")" ], "id": "72fd12606ce90294", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "import os\n", "ROOT = \"./\"\n", "DATA_ROOT = os.path.join(ROOT, \"dataset\", \"age\")" ], "id": "ab0b5eb9695b3dc4", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Train the model\n", "results = model.train(\n", " data=DATA_ROOT,\n", " epochs=50,\n", " imgsz=64,\n", " device=0,\n", " save=True,\n", " save_period=1, # Save checkpoint every 10 epochs\n", " project=\"Age_Detection\", # Name of the project directory where training outputs are saved.\n", " name=\"v1_epochs_10_imgsz_64\", # Name of the training run.\n", " dropout=0.1,\n", " plots=True # Generates and saves plots of training, validation metrics, and prediction examples.\n", ")\n" ], "id": "58c0d660d6058344", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": "### Test Performance", "id": "1dbb9edb4af0bd27" }, { "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:47:23.186541Z", "start_time": "2025-08-31T10:47:20.839587Z" } }, "cell_type": "code", "source": "# !pip install scikit-learn seaborn", "id": "24230875529662bf", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: scikit-learn in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (1.7.1)\n", "Collecting seaborn\n", " Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)\n", "Requirement already satisfied: numpy>=1.22.0 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from scikit-learn) (2.2.6)\n", "Requirement already satisfied: scipy>=1.8.0 in d:\\documents\\personal 
projects\\age_predictor\\.venv\\lib\\site-packages (from scikit-learn) (1.15.3)\n", "Requirement already satisfied: joblib>=1.2.0 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from scikit-learn) (1.5.2)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from scikit-learn) (3.6.0)\n", "Requirement already satisfied: pandas>=1.2 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from seaborn) (2.3.2)\n", "Requirement already satisfied: matplotlib!=3.6.1,>=3.4 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from seaborn) (3.10.5)\n", "Requirement already satisfied: contourpy>=1.0.1 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (1.3.2)\n", "Requirement already satisfied: cycler>=0.10 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (4.59.1)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (1.4.9)\n", "Requirement already satisfied: packaging>=20.0 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (25.0)\n", "Requirement already satisfied: pillow>=8 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (11.3.0)\n", "Requirement already satisfied: pyparsing>=2.3.1 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (3.2.3)\n", "Requirement already satisfied: 
python-dateutil>=2.7 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from pandas>=1.2->seaborn) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from pandas>=1.2->seaborn) (2025.2)\n", "Requirement already satisfied: six>=1.5 in d:\\documents\\personal projects\\age_predictor\\.venv\\lib\\site-packages (from python-dateutil>=2.7->matplotlib!=3.6.1,>=3.4->seaborn) (1.17.0)\n", "Using cached seaborn-0.13.2-py3-none-any.whl (294 kB)\n", "Installing collected packages: seaborn\n", "Successfully installed seaborn-0.13.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "[notice] A new release of pip is available: 25.0.1 -> 25.2\n", "[notice] To update, run: python.exe -m pip install --upgrade pip\n" ] } ], "execution_count": 7 }, { "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:48:42.432672Z", "start_time": "2025-08-31T10:48:40.471859Z" } }, "cell_type": "code", "source": [ "from ultralytics import YOLO\n", "\n", "model_path = r\"./Age_Detection/v1_epochs_10_imgsz_64/weights/best.pt\"\n", "model = YOLO(model_path) # load a custom model" ], "id": "83e459f62cea24f7", "outputs": [], "execution_count": 10 }, { "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:50:42.200470Z", "start_time": "2025-08-31T10:50:13.391276Z" } }, "cell_type": "code", "source": [ "import os\n", "from tqdm import tqdm\n", "\n", "test_dir = os.path.join(\"../dataset\", \"age\", \"test\")\n", "ages = list(os.listdir(test_dir))\n", "\n", "results = {}\n", "for age in tqdm(ages):\n", " image_path = os.path.join(test_dir, age)\n", " results[age] = model(image_path, verbose=False)" ], "id": "782205439db613df", "outputs": [ { "name": "stderr", "output_type": "stream", 
"text": [ "100%|██████████| 5/5 [00:28<00:00, 5.76s/it]\n" ] } ], "execution_count": 13 }, { "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:52:53.638348Z", "start_time": "2025-08-31T10:52:53.626345Z" } }, "cell_type": "code", "source": "results[\"0-12\"][0]", "id": "8a67918d4be6808f", "outputs": [ { "data": { "text/plain": [ "ultralytics.engine.results.Results object with attributes:\n", "\n", "boxes: None\n", "keypoints: None\n", "masks: None\n", "names: {0: '0-12', 1: '13-20', 2: '21-44', 3: '45-64', 4: '65+'}\n", "obb: None\n", "orig_img: array([[[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " ...,\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]],\n", "\n", " [[0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " ...,\n", " [0, 0, 0],\n", " [0, 0, 0],\n", " [0, 0, 0]]], shape=(640, 640, 3), dtype=uint8)\n", "orig_shape: (640, 640)\n", "path: 'D:\\\\Documents\\\\Personal Projects\\\\Age_Predictor\\\\notebooks\\\\..\\\\dataset\\\\age\\\\test\\\\0-12\\\\1001_0-12.png'\n", "probs: ultralytics.engine.results.Probs object\n", "save_dir: None\n", "speed: {'preprocess': 3.8295999984256923, 'inference': 7.545299973571673, 'postprocess': 0.0872999953571707}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 17 }, { "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:56:16.496688Z", "start_time": "2025-08-31T10:56:16.215627Z" } }, "cell_type": "code", "source": [ "images = []\n", "true_labels = []\n", 
"predicted_labels = []\n", "\n", "mapping = results[\"0-12\"][0].names\n", "\n", "for age in ages:\n", " for result in results[age]:\n", " img_path = result.path\n", "\n", " img_name = os.path.basename(img_path)\n", " images.append(img_name)\n", "\n", " true_label = img_name.split(\"_\")[-1].split(\".\")[0]\n", " true_labels.append(true_label)\n", "\n", " label_index = result.probs.top1\n", " predicted_label = mapping[label_index]\n", " predicted_labels.append(predicted_label)" ], "id": "a0889abe317861c1", "outputs": [], "execution_count": 18 }, { "metadata": { "ExecuteTime": { "end_time": "2025-08-31T10:58:45.347150Z", "start_time": "2025-08-31T10:58:43.035241Z" } }, "cell_type": "code", "source": [ "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "%matplotlib inline\n", "\n", "class_names = list(mapping.values())\n", "\n", "# Accuracy is a great top-level metric to see overall correctness.\n", "accuracy = accuracy_score(true_labels, predicted_labels)\n", "print(f\"\\nOverall Model Accuracy: {accuracy:.4f}\")\n", "\n", "# A classification report provides a more detailed breakdown per class.\n", "# It shows precision, recall, and F1-score for each class.\n", "print(\"\\n--- Classification Report ---\")\n", "print(classification_report(true_labels, predicted_labels, target_names=class_names))\n", "\n", "# --- Step 3: Visualize with a Confusion Matrix ---\n", "# The confusion matrix provides a visual representation of the performance.\n", "# Each row represents the true class, and each column represents the predicted class.\n", "# This helps identify which classes the model is confusing with others.\n", "print(\"\\nGenerating Confusion Matrix...\")\n", "cm = confusion_matrix(true_labels, predicted_labels)\n", "\n", "plt.figure(figsize=(10, 8))\n", "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)\n", 
"plt.title('Confusion Matrix')\n", "plt.ylabel('True Label')\n", "plt.xlabel('Predicted Label')\n", "plt.show()\n", "\n", "print(\"\\n--- Evaluation Complete ---\")" ], "id": "fa6a27106dfd86c9", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Overall Model Accuracy: 0.7914\n", "\n", "--- Classification Report ---\n", " precision recall f1-score support\n", "\n", " 0-12 0.89 0.95 0.92 220\n", " 13-20 0.69 0.61 0.65 179\n", " 21-44 0.77 0.75 0.76 400\n", " 45-64 0.79 0.82 0.80 379\n", " 65+ 0.78 0.83 0.80 126\n", "\n", " accuracy 0.79 1304\n", " macro avg 0.79 0.79 0.79 1304\n", "weighted avg 0.79 0.79 0.79 1304\n", "\n", "\n", "Generating Confusion Matrix...\n" ] }, { "data": { "text/plain": [ "
" ], "image/png": "" }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "--- Evaluation Complete ---\n" ] } ], "execution_count": 19 }, { "metadata": {}, "cell_type": "code", "source": [ "from tqdm import tqdm\n", "\n", "corrects = {age: 0 for age in ages}\n", "total = {age: len(os.listdir(os.path.join(test_dir, age))) for age in ages}\n", "\n", "mapping = results[\"0-12\"][0].names\n", "\n", "for age in ages:\n", " for result in tqdm(results[age], total=total[age], desc=f\"Calculating accuracy for {age} group...\"):\n", " label_index = result.probs.top1\n", " label = mapping[label_index]\n", " if label == age:\n", " corrects[age] += 1\n", " print(f\"{age}: {corrects[age]}/{total[age]} - {corrects[age]/total[age] * 100:.2f}%\")\n" ], "id": "c3524a8dc5024372", "outputs": [], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }