{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Logistic Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The purpose of this notebook is to validate ``macaw``'s logistic regression code. The original example is available in scikit-learn's documentation here: http://scikit-learn.org/stable/auto_examples/linear_model/plot_logistic.html#sphx-glr-auto-examples-linear-model-plot-logistic-py "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from macaw.objective_functions import LogisticRegression\n",
    "from macaw.models import LogisticModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# generate some toy data\n",
    "xmin, xmax = -5, 5\n",
    "n_samples = 100\n",
    "np.random.seed(0)\n",
    "X = np.random.normal(size=n_samples)\n",
    "# use the builtin float: the np.float alias was removed in NumPy 1.24\n",
    "y = (X > 0).astype(float)\n",
    "X[X > 0] *= 4\n",
    "X += .3 * np.random.normal(size=n_samples)\n",
    "X = X[:, np.newaxis]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = LogisticRegression(y=y, X=X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = loss.fit(x0=np.random.rand(2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 6.90880141, -1.64912633])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
"output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAADFCAYAAACo92whAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGdVJREFUeJzt3X1sG/ed5/H3l6T4oGfJlmxLjqPEsZ3LNclerMTJptgk\nrdfr2sX2ih6Qh+K2TXYRFGiL7h/Xaw+7t7dA/7leN4dr0W16vjbY7mHbLZBtt7ki3U3k7jYN0gZ2\nAj/EdmPFseRnWyYtS6JISeR8748ZyhQtSnQ9JDXS9wUIQ5GjH78zIj+a32+o+YmqYowx8wnVuwBj\nzNJlAWGMKcsCwhhTlgWEMaYsCwhjTFkWEMaYsiwgjDFlWUAYY8qygDDGlBWp1xOvXr1a+/r66vX0\nxqxYb7311mVV7apk3boFRF9fH/v376/X0xuzYonIcKXrWhfDGFOWBYQxpiwLCGNMWRYQxpiyLCCM\nMWUtehZDRF4APgpcUtUPzPO4AF8HdgGTwKdV9W2/CzUwODjIwMAAhw4d4ujRowwODjI2NkZDQwO9\nvb10dXUxMzNDOBwmk8mQTCYREVatWsXExASpVAqAaDTK9PQ0ExMTOI5DJOK+DBoaGujr6+MTn/gE\n58+fZ+/evWSzWdavX8+9995La2srvb29bN++nU2bNgHw6quv8sILL3D69GluueUWnnnmGYA59+3Y\nsYOf//znDAwMkE6naWpqYuvWrezevXtOW6XbefbsWXp7e7njjjt444032Lt3L8lkkng8ztq1a2lv\nb8dxHNrb27nnnnvmbau03e9///sMDAyQSqVIJBK0tLSQy+Vobm7mwQcf5Kmnnpq3jdLt3LFjB9ls\nlrNnz9LQ0ICqksvlZvcPMGcbKqmtdP0bbaMaZLErSonI7wETwN+WCYhdwOdxA2Ib8HVV3bbYE/f3\n96ud5qzc4OAg3/3ud8nn8+zdu5fjx4+TyWSIxWI4jsP09DTxeJw77riDkydPMj09TVdXF+l0mvHx\ncUSE5qYmcuk0kVyOCO5fhzAgQEiEaDRKU1MTExMTJBIJ1nR34zgOFy9eJJFIsHv3btauXcvo1as8\n8fjjnDlzhuf+6q9ob2+ntbWVsbExTp06BcCGDRtobW3l/PnzHDt2jJmZGeLxONPT0ziOQygU4pFH\nHqGvr4/HH3+c2267DYCTJ0/ywx/+kPb2dpqbmxkeHub111/HcRyy2SwAV65cIRaLISLceuutJBIJ\ntmzZQjgcntNWsZMnT7Jnzx6OHDnC2NgYjuOQTCZRVTo6Oujr6yOXy3H77bfz7LPPzmnjl7/8Jc89\n99zsdp4/f56hoSEee+wxenp6+NWvfgXAQw89RDweZ2hoaLa25uZmJiYmGB0dXbC24m2emJhgeHgY\nVaWvr6+iNmatWQPx+IKvJRF5S1X7F1ypsG4ll5wTkT7gp2UC4n8D/6qqP/C+fxd4VFXPL9SmBcSN\nef755xkfH+fgwYO89tprXL58GVUlEokwNTWFqtLQ0EAoFCISiaCq5PN5Vkci3DkxwW2qrI1ECOVy\nqCqKGwwAihsQoVAIEcFxHESENWvWkE6nyefzALS0tLB161ay2SzRaJTh4WGmpqZIJBKzdb7//vsA\n3H777QAkk0kuXLiAeAHkOI77nKrE43EeeOABotEo999/PwD79u2bDTuAoaEhLl68SDabpbGxkWw2\nSz6fZ2ZmhoaGBlpbW2ltbSUSibB27do5bRXbt28fJ06cYHR0FIBsNjsbONFolJaWFlpbW1FVNm7c\nOKeNF198cc52JpNJstkssViMVatWkcvlAIhEIvT19TE4OAgw5699YZ+Vq614m4EbbmPWpz4FCwUI\nNxYQfnxQqhc4XfT9Ge++6wJCRJ4FngX3L4yp3NmzZ1m/fj2jo6Nks9nZroHjOBRCXlW
Znp4mEonQ\nEImwbXKSx/J5cqqoKmHHIQukgRnAAXLeElUiIjiqIF50xGJcSacJR6Pk83miuRxbOjtxVLl05QqD\nU1N0dnaSDl0byjotgqrS2tTkfp9McgE3jKJAKBJBASefR2Zm+J1167h05Qr3e6+H0/v307FuHTmv\nhtPDw1yORBhXpauxkVQ2SzgaJT0zQzwcZlSEW9ramJycpLOkrWKn9+/nfEMDqXCYaDTKlakppkTc\n4AqFaPLaSafTxEKhOW2UbufpVIqG1lYyk5Okgaa2NgAmJydZ3dnJ+YYGANZ1ds62Udhn5Wor3mbg\nhtuYtcjRw42q6ScpVXUPsAfcI4haPnfQ9fb2MjY2Rnt7O/F4nImJCfL5PJFIBPHelIW/0gAfzGZ5\nwHsDHBbhgAjJaJTxmRkcx5k9zNdCeITDNEQihEIhcrkcoVCIe7q6uKTK1NTUbA23Pfwwo6OjtLS0\ncKi5matXr9JZ9CLeOzQEQGrjRgAGHYcDly+7XZzGxtmjkXwoREdHBxs+8AFaWlrAG7u4ODXFe+Pj\ntLe3A7Avl+PgwYOMhkJ0tbVxxetOZRyHRCLBLb299Pb0kEgkmC5pq9jFqSlee+UVhrz6RnG7KuAe\nGa1bt45bentxHIf4jh1z2jj06qtztvN4Ps/Vq1dpXb2azZs3k81mUVUSiQSNDz/Ma2NjADQ9/PBs\nG4V9Vq624m0GbriNavHjLMZZ4Jai79d79xkfbd++nVQqRU9PD11dXYTDYWa8N3s4HJ7tFmzcuJEN\nuRz9U1NEEwlebGjgxyIMhUKo128vrFvcvSzcbmpqQkRIJBJMTEwQj8dJp9OoKg899BCjo6OkUim2\nb9/OM888QyqVIpVK4TgOqVSKtrY22traZu+LxWLE4/HZ55uZmZk9Arr//vtn2yrdztHRURzHoaen\nh0gkQmdnJxMTEzQ0NJDJZGbbjMViTE5O0tPTc11bpfuvu7ubhoYG0un07D7L5/OEw2Ha2tpIJpN0\nd3df10bpdsbjccbHx7n33nvZvHkzyWSSVCrF5s2bGR0dpauri+7u7tltKN5nC/1ui9fv7u6mq6ur\n4jaqxY8xiN3A57g2SPkNVX1gsTZtDOLGFUa6Dx8+zJEjR+Y9i5GbmeFjFy/SNjnJz4H9jY3u4XE6\nTTKZBK6dxSiML1R6FqOtrY2enp55z2KcOXOG9evXzzmLUbhvvrMY/f397Nq1a8GzGOfOnaOnp+e6\nsxiJRII1a9bMnsXo6Ojg7rvvrvgsRqGdxsZGmpubyeVytLS0sG3btkXPYhRvUzab5dy5c3PGfAr7\nB5izDZWexShe/0bbqJSvg5Qi8gPgUWA1cBH4b0ADgKp+2zvN+U1gJ+5pzqdVddF3vgVElZw8Cd/7\nHjQ3wxe+AF5f1pgCXwcpVfXJRR5X4LMV1maq7fBhd3nffRYO5qbZJymXE8eBY8fc23ffXd9azLJg\nAbGcXLgAmQx0dEBXRdcDMWZBFhDLiXcKD7tSl/GJBcRy4n3M2QLC+MUCYjm5cMFd9vTUtw6zbFhA\nLBfZLIyOQiQCq1bVuxqzTFhALBcXL7rLNWsgZL9W4w97JS0XIyPu0s5eGB9ZQCwX3sVgrHth/GQB\nsVwUAqLoPyuNuVkWEMuFBYSpAguI5UD1WkB0dNS3FrOsWEAsB5OTkMtBIuH7FYXMymYBsRx4Vx+i\ntbW+dZhlxwJiObCAMFViAbEcWECYKrGAWA4sIEyVWEAsBxYQpkosIJaDiQl32dxc3zrMslNRQIjI\nThF5V0TeE5Evz/N4m4j8PxE5KCJHRORp/0s1ZVlAmCpZNCBEJAz8NfAR4C7gSRG5q2S1zwJHVfVe\n3CtgPyciUZ9rNeWk0+7Sm83KGL9UcgTxAPCeqr6vqtPA3wMfK1lHgRbvEvjNQAp3VjdTbaruB6XA\nAsL4rpKAKDf3ZrFvAv8GOAccBr6gqk5pQyLyrIj
sF5H9I4V/TzY3J5Nxr2Ydj0M4XO9qzDLj1yDl\nHwAHgB7gd4Bvish1Q+qqukdV+1W1v8uuW+AP616YKqokICqZe/Np4Efqeg84CdzpT4lmQYWAsAFK\nUwWVBMQ+YJOI3OYNPD4BvFSyzingwwAisgbYArzvZ6GmjEJANDbWtw6zLFUy9V5ORD4H/DMQBl5Q\n1SMi8hnv8W8DXwH+RkQOAwJ8SVUvV7FuU2ADlKaKFg0IAFV9GXi55L5vF90+B+zwtzRTkUzGXSYS\n9a3DLEv2Scqgs4AwVWQBEXSFLoYFhKkCC4igKxxB2CClqQILiKCzLoapIguIoLOAMFVkARF0FhCm\niiwggkzVAsJUlQVEkE1Pu/+oFY3aP2qZqrCACLJs1l3aXBimSiwggmxqyl3GYvWtwyxbFhBBVjiC\nsIAwVWIBEWSFIwjrYpgqsYAIMutimCqzgAgy62KYKrOACDLrYpgqs4AIMutimCqzgAgy62KYKrOA\nCDLrYpgqs4AIMutimCqzgAgy62KYKvNl8l5vnUdF5IA3ee8v/C3TzMu6GKbKFr2qddHkvb+PO+3e\nPhF5SVWPFq3TDnwL2Kmqp0Sku1oFmyLWxTBV5tfkvU/hzqx1CkBVL/lbppmX/TenqTK/Ju/dDHSI\nyL+KyFsi8kfzNWST9/pI1Y4gTNX5NUgZAbYCu3En8v2vIrK5dCWbvNdHuRzk8+6FYiIVzX9kzA2r\n5JVVyeS9Z4CkqqaBtIi8BtwLHPelSnM9G6A0NeDX5L0/AT4oIhERaQS2Acf8LdXMYd0LUwO+TN6r\nqsdE5J+AQ4ADfEdV36lm4SuefQbC1IAvk/d6338N+Jp/pZkFWRfD1IB9kjKorIthasACIqisi2Fq\nwAIiqKyLYWrAAiKorIthasACIqisi2FqwAIiqKyLYWrAAiKorIthasACIqjsPzlNDVhABJUdQZga\nsIAIKhukNDVgARFUNkhpasACIqisi2FqwAIiiBwHpqdBBKLReldjljELiCAqPnoQqW8tZlmzgAgi\n616YGrGACCI7g2FqxAIiiOwMhqkRC4ggsi6GqRELiCCyLoapEQuIILIuhqkR3ybv9da7X0RyIvIf\n/CvRXMe6GKZGFg2Iosl7PwLcBTwpIneVWe+rwCt+F2lKWBfD1Ihfk/cCfB74B8Am7q0262KYGvFl\n8l4R6QU+Djy/UEM2ea9PrIthasSvQcr/BXxJVZ2FVrLJe31iF4sxNeLX5L39wN+L+38Bq4FdIpJT\n1X/0pUozlx1BmBqpJCBmJ+/FDYYngKeKV1DV2wq3ReRvgJ9aOFRRJuMu7QjCVJkvk/dWuUZTqtDF\nSCTqW4dZ9nybvLfo/k/ffFlmQTYGYWrEPkkZNPm8e7GYUMguFmOqzgIiaIqPHuxiMabKLCCCxroX\npoYsIILGzmCYGrKACBo7g2FqyAIiaKyLYWrIAiJoCl0MO4IwNWABETR2BGFqyAIiaGyQ0tSQBUTQ\n2CClqSELiKCxLoapIQuIoLEuhqkhC4igsS6GqSELiKCxLoapIQuIoLHPQZgasoAIElW73JypKQuI\nIJmackMiFnOvB2FMldmrLEise2FqzAIiSNJpd9nUVN86zIrhy9ycIvJJETkkIodF5A0Rudf/Uo0F\nhKk1v+bmPAk8oqp3A18B9vhdqOFaQDQ21rcOs2L4Mjenqr6hqle8b3+NO7mO8ZsdQZga82VuzhJ/\nDPxsvgdsbs6bNDnpLi0gTI34OkgpIo/hBsSX5nvc5ua8SXYEYWrMr7k5EZF7gO8AH1HVpD/lmTls\nDMLUWCVHELNzc4pIFHduzpeKVxCRDcCPgP+oqsf9L9MAMDHhLu0IwtSIX3Nz/gWwCviWN8N3TlX7\nq1f2CjU+7i5bW+tbh1kxfJmbU1X/BPgTf0szc+Ry7iBlKGRHEKZm7JOUQVE4emhutin3TM1YQASF\ndS9MHVhABMX
YmLtsaalvHWZFsYAICjuCMHVgAREUV7xPsre11bcOs6JYQARFISA6O+tbh1lRLCCC\nohAQHR31rcOsKBYQQeA4FhCmLiwggmB8HPJ59zMQ0Wi9qzEriAVEEFy65C5Xr65vHWbFsYAIgosX\n3eWaNfWtw6w4FhBBYAFh6sQCIgguXHCXFhCmxiwglrrJSRgZgUjEAsLUnAXEUnfqlLtcv94NCWNq\nyAJiqTtxwl3eemt96zArkgXEUuY4cOyYe/vOO+tbi1mRLCCWshMn3OtQdnTA2rX1rsasQBYQS5Uq\nvP66e7u/364iZerCAmKpevttGB52Z/LeurXe1ZgVqqJhcRHZCXwd96rW31HV/17yuHiP7wImgU+r\n6ts+11o3g4ODDAwMcPbsWSKRCCLCzMwMY2NjHDx4kDNnzhCPx7nvvvvo6OjgwIEDDA8Pk0qlmJ6e\nRlWJRCLE43FisRhNTU3E43HGxsZIpVLMzMwQiUTI5XLkp6fZCvwBbnr/GDj6539OOBympaWFnTt3\n8qEPfYhXXnmF48ePMzU1RSwWo7u7m23btvHJT36STZs21XeHmWVDVHXhFdzJe48Dv4877d4+4ElV\nPVq0zi7g87gBsQ34uqpuW6jd/v5+3b9//8LVjYxcmyymYL56K73vt/j5oaEhXnzxRdra2pjKZnnz\nzTcBWLduHW+88QaZTIburi7y+TwjIyNEo1HC4TDjY2Pk8vk5TQoQCYdnv8/l84RFiKrSBHQBG4HC\ntDivAf9S9PPxeBxVJRqNsmXLFq5cuUIqlUJVufPOO4nFYmzevJkvfvGLFhKmLBF5q9JpKSo5gpid\nvNdrvDB579GidT4G/K26afNrEWkXkXWqev4Ga5/rF7+Ad965qSZu1si+fTw0PU08HmdoaIiP5nIA\nnHv3XT7uOCBC2PtX7IzjkE+nERHyjoMCpdEj+TwiguMFkHjL4vUu4QbDb0p/1jtyUVVGRkbI5/M0\nNzeTy+U4d+4cW7duZWRkhIGBAQsI44tKAmK+yXtLjw7KTfA7JyBE5FngWYANGzYs/syrV187/19u\nkG6++2/mvpL7hw4coHPNGrKhECdOn6bJuybk0VOnaPCOFvLekUI2HGYql0MAB2YDojQkQkWPA0yL\nkFYliTun4aX5qwLAcRxEhMnJSSKRCLFYDBEhk8mQSCS4cuUK586dW6AFYypX04/mqeoeYA+4XYxF\nf+DRR92vOro8NsbJ8XHa29s5mMmQyWQQEfavXs3MzAwAUe8aDalUihnV2b/0+Xye0i5cKBQiFAqR\n845ERGTOEcViCj/f2NhIPp8nn8+Ty+VIJBJkMhlisRg9PT0+7gGzklVyFqOSyXsrmuA3iLZv304q\nlWJ0dJTNmzeTSqVIJpM88sgjqCrpdJp4PE4kEiGfz5NIJGhubga4LhzAfYMDRLyPTYsIjuNUVEth\nsDMWi9HV1UUsFmNiYoJMJkNPTw+pVIquri62b9/u09abla6SQcoI7iDlh3Hf9PuAp1T1SNE6u4HP\ncW2Q8huq+sBC7VY0SLlEFM5inDt3jnA4jIiQy+W4evVqRWcxAMLh8HVnMcbHx0kmk3POYhTGGIpF\nIpF5z2IMDg6SzWaJx+N0dXXZWQxTEV8HKSucvPdl3HB4D/c059O/bfFL0aZNm5bcm+7pp5fVLjZL\nlF+T9yrwWX9LM8bUm32S0hhTlgWEMaYsCwhjTFkWEMaYshY9zVm1JxYZAYYrWHU1cLnK5VTC6pjL\n6lhaNUDlddyqql2VNFi3gKiUiOyv9Jyt1WF1rNQaqlWHdTGMMWVZQBhjygpCQOypdwEeq2Muq+Oa\npVADVKGOJT8GYYypnyAcQRhj6sQCwhhT1pILCBH5SxE5KyIHvK9dZdbbKSLvish7IvLlKtTxNRH5\njYgcEpEfi0h7mfWGROSwV6tv/7++2PaJ6xve44dE5D6/nrvoOW4RkX8RkaMic
kREvjDPOo+KyNWi\n39dfVKGOBfdxjfbFlqJtPCAiYyLypyXrVGVfiMgLInJJRN4puq9TRF4VkUFv2VHmZ2/ufaKqS+oL\n+EvgPy2yThg4AdwORIGDwF0+17EDiHi3vwp8tcx6Q8Bqn5970e3D/ff6n+FeC/dB4M0q/C7WAfd5\nt1twrwtSWsejwE+r/JpYcB/XYl/M8/u5gPuBo6rvC+D3gPuAd4ru+x/Al73bX57v9enH+2TJHUFU\naPZCuqo6DRQupOsbVX1FVXPet7/GvUpWrVSyfbMXClbVXwPtIrLOzyJU9bx60xeo6jhwDPdao0tN\n1fdFiQ8DJ1S1kk8C3zRVfQ1Ildz9MeB73u3vAf9+nh+96ffJUg2Iz3uHii+UOXQqd5HcankG9y/U\nfBQYEJG3vIvy+qGS7avpPhCRPuDfAW/O8/Dver+vn4nIv63C0y+2j2v9engC+EGZx6q9LwrW6LWr\nxl8A1syzzk3vl7rMJy8iA8B8k03+GfA88BXcF8VXgOdw36A1rUNVf+Kt82dADvi7Ms18UFXPikg3\n8KqI/MZL/GVDRJqBfwD+VFXHSh5+G9igqhPeeNE/An5ffmvJ7GMRiQJ/CPyXeR6uxb64jqqqiFTl\n8wp1CQhVreiqqiLyf4CfzvOQLxfJXawOEfk08FHgw+p16uZp46y3vCQiP8Y9rLvZF++SuVCwiDTg\nhsPfqeqPSh8vDgxVfVlEviUiq1XVt39eqmAf1/KiyR8B3lbVi/PUWfV9UeSieHPPeN2p+WZLuOn9\nsuS6GCV9x48D882csw/YJCK3eYn+BPCSz3XsBP4z8IeqOllmnSYRaSncxh3Y9GOmn0q27yXgj7wR\n/AeBq3qzExWVEBEBvgscU9X/WWadtd56iMgDuK+ppI81VLKPq74vijxJme5FtfdFiZeAT3m3PwX8\nZJ51bv59Us3R3t9yxPb/AoeBQ97GrPPu7wFeLlpvF+6o+gncLoHfdbyH23874H19u7QO3NHhg97X\nET/rmG/7gM8An/FuC/DX3uOHgf4q7IMP4nb1DhXth10ldXzO2/aDuIO5v+tzDfPu41rvC+95mnDf\n8G1F91V9X+AG0nlgBncc4Y+BVcBeYBAYADqr8T6xj1obY8pacl0MY8zSYQFhjCnLAsIYU5YFhDGm\nLAsIY0xZFhDGmLIsIIwxZf1/TI/Yq08S5+AAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(1, figsize=(4, 3))\n", "plt.scatter(X.ravel(), y, color='black', alpha=.4)\n", "X_test = np.linspace(-5, 10, 300)\n", "model = LogisticModel(X_test)\n", "plt.plot(X_test, model(res.x[0], res.x[1]), color='red', linewidth=2, alpha=.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }