{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Callback for Study.optimize\n\nThis tutorial showcases how to use & implement Optuna ``Callback`` for :func:`~optuna.study.Study.optimize`.\n\n``Callback`` is called after every evaluation of ``objective``, and\nit takes :class:`~optuna.study.Study` and :class:`~optuna.trial.FrozenTrial` as arguments, and does some work.\n\n:class:`~optuna.integration.MLflowCallback` is a great example.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Stop optimization after some trials are pruned in a row\n\nThis example implements a stateful callback which stops the optimization\nif a certain number of trials are pruned in a row.\nThe number of trials pruned in a row is specified by ``threshold``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import optuna\n\n\nclass StopWhenTrialKeepBeingPrunedCallback:\n    def __init__(self, threshold: int):\n        self.threshold = threshold\n        self._consequtive_pruned_count = 0\n\n    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:\n        if trial.state == optuna.trial.TrialState.PRUNED:\n            self._consequtive_pruned_count += 1\n        else:\n            self._consequtive_pruned_count = 0\n\n        if self._consequtive_pruned_count >= self.threshold:\n            study.stop()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This objective prunes all the trials except for the first 5 trials (``trial.number`` starts with 0).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def objective(trial):\n    if trial.number > 4:\n        raise optuna.TrialPruned\n\n    return trial.suggest_float(\"x\", 0, 1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here, we set the threshold to ``2``: optimization finishes once two trials are pruned in a row.\nSo, we expect this study to stop after 7 trials.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import logging\nimport sys\n\n# Add stream handler of stdout to show the messages\noptuna.logging.get_logger(\"optuna\").addHandler(logging.StreamHandler(sys.stdout))\n\nstudy_stop_cb = StopWhenTrialKeepBeingPrunedCallback(2)\nstudy = optuna.create_study()\nstudy.optimize(objective, n_trials=10, callbacks=[study_stop_cb])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As you can see in the log above, the study stopped after 7 trials as expected.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}