{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# User-Defined Pruner\n\nIn :mod:`optuna.pruners`, we described how an objective function can optionally include\ncalls to a pruning feature which allows Optuna to terminate an optimization\ntrial when intermediate results do not appear promising. In this document, we\ndescribe how to implement your own pruner, i.e., a custom strategy for\ndetermining when to stop a trial.\n\n## Overview of Pruning Interface\n\nThe :func:`~optuna.study.create_study` constructor takes, as an optional\nargument, a pruner inheriting from :class:`~optuna.pruners.BasePruner`. The\npruner should implement the abstract method\n:func:`~optuna.pruners.BasePruner.prune`, which takes arguments for the\nassociated :class:`~optuna.study.Study` and :class:`~optuna.trial.Trial` and\nreturns a boolean value: :obj:`True` if the trial should be pruned and :obj:`False`\notherwise. Using the Study and Trial objects, you can access all other trials\nthrough the :func:`~optuna.study.Study.get_trial` method and, and from a trial,\nits reported intermediate values through the\n:func:`~optuna.trial.FrozenTrial.intermediate_values` (a\ndictionary which maps an integer ``step`` to a float value).\n\nYou can refer to the source code of the built-in Optuna pruners as templates for\nbuilding your own. 
In this document, for illustration, we describe the\nconstruction and usage of a simple (but aggressive) pruner which prunes trials\nthat are in last place compared to completed trials at the same step.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Please refer to the documentation of :class:`~optuna.pruners.BasePruner` or,\n    for example, :class:`~optuna.pruners.ThresholdPruner` or\n    :class:`~optuna.pruners.PercentilePruner` for more robust examples of pruner\n    implementation, including error checking and complex pruner-internal logic.</p></div>\n\n## An Example: Implementing ``LastPlacePruner``\n\nWe aim to optimize the ``loss`` and ``alpha`` hyperparameters for a stochastic\ngradient descent classifier (``SGDClassifier``) run on the sklearn iris dataset. We\nimplement a pruner which terminates a trial at a certain step if it is in last\nplace compared to completed trials at the same step. We begin considering\npruning after a \"warmup\" of 1 training step and 5 completed trials. For\ndemonstration purposes, we :func:`print` a diagnostic message from ``prune`` when\nit is about to return :obj:`True` (indicating pruning).\n\nNote that the ``SGDClassifier`` score, as it is evaluated on a holdout set, can\ndecrease with enough training steps due to overfitting. This\nmeans that a trial could be pruned even if it had a favorable (high) value at a\nprevious training step. After pruning, Optuna will take the intermediate value\nlast reported as the value of the trial.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom sklearn.datasets import load_iris\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.linear_model import SGDClassifier\n\nimport optuna\nfrom optuna.pruners import BasePruner\nfrom optuna.trial._state import TrialState\n\n\nclass LastPlacePruner(BasePruner):\n    def __init__(self, warmup_steps, warmup_trials):\n        self._warmup_steps = warmup_steps\n        self._warmup_trials = warmup_trials\n\n    def prune(self, study: \"optuna.study.Study\", trial: \"optuna.trial.FrozenTrial\") -> bool:\n        # Get the latest score reported from this trial\n        step = trial.last_step\n\n        if step:  # trial.last_step == None when no scores have been reported yet\n            this_score = trial.intermediate_values[step]\n\n            # Get scores from other trials in the study reported at the same step\n            completed_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))\n            other_scores = [\n                t.intermediate_values[step]\n                for t in completed_trials\n                if step in t.intermediate_values\n            ]\n            other_scores = sorted(other_scores)\n\n            # Prune if this trial at this step has a lower value than all completed trials\n            # at the same step. Note that steps will begin numbering at 0 in the objective\n            # function definition below.\n            if step >= self._warmup_steps and len(other_scores) > self._warmup_trials:\n                if this_score < other_scores[0]:\n                    print(f\"prune() True: Trial {trial.number}, Step {step}, Score {this_score}\")\n                    return True\n\n        return False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Lastly, let's confirm the implementation is correct with the simple hyperparameter optimization.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def objective(trial):\n    iris = load_iris()\n    classes = np.unique(iris.target)\n    X_train, X_valid, y_train, y_valid = train_test_split(\n        iris.data, iris.target, train_size=100, test_size=50, random_state=0\n    )\n\n    loss = trial.suggest_categorical(\"loss\", [\"hinge\", \"log\", \"perceptron\"])\n    alpha = trial.suggest_float(\"alpha\", 0.00001, 0.001, log=True)\n    clf = SGDClassifier(loss=loss, alpha=alpha, random_state=0)\n    score = 0\n\n    for step in range(0, 5):\n        clf.partial_fit(X_train, y_train, classes=classes)\n        score = clf.score(X_valid, y_valid)\n\n        trial.report(score, step)\n\n        if trial.should_prune():\n            raise optuna.TrialPruned()\n\n    return score\n\n\npruner = LastPlacePruner(warmup_steps=1, warmup_trials=5)\nstudy = optuna.create_study(direction=\"maximize\", pruner=pruner)\nstudy.optimize(objective, n_trials=50)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}