{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic Regression Applied to Classification of Breast Tumors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we use logistic regression to classify breast tumors into two classes: benign or malignant.\n", "The dataset used in this short tutorial is available here: https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/. *Note: a few values were missing (labeled as '?'); they were replaced with zeros.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The full documentation of the dataset is in the ``breast-cancer-wisconsin.names`` file available at the link above. Nonetheless, I will briefly describe the characteristics of this dataset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This dataset has nine integer-valued features that biologically characterize a given tumor, e.g., uniformity of cell size, clump thickness, etc. Every sample in the dataset has a label (the ``Class`` column) which indicates whether the tumor is benign or malignant. Benign samples have ``Class == 2``, whereas malignant samples have ``Class == 4``." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Data Visualization\n", "Let's load and visualize the dataset using pandas:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "np.random.seed(123)  # makes the random train/test split below reproducible" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "names = ['Sample code number', 'Clump Thickness', 'Uniformity of Cell Size',\n", " 'Uniformity of Cell Shape', 'Marginal Adhesion', 'Single Epithelial Cell Size',\n", " 'Bare Nuclei', 'Bland Chromatin', 'Normal Nucleoli', 'Mitoses', 'Class']" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "breast_cancer_df = pd.read_csv('breast-cancer-wisconsin.data', names=names)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
699 rows × 11 columns
" ], "text/plain": [ " Sample code number Clump Thickness Uniformity of Cell Size \\\n", "0 1000025 5 1 \n", "1 1002945 5 4 \n", "2 1015425 3 1 \n", "3 1016277 6 8 \n", "4 1017023 4 1 \n", "5 1017122 8 10 \n", "6 1018099 1 1 \n", "7 1018561 2 1 \n", "8 1033078 2 1 \n", "9 1033078 4 2 \n", "10 1035283 1 1 \n", "11 1036172 2 1 \n", "12 1041801 5 3 \n", "13 1043999 1 1 \n", "14 1044572 8 7 \n", "15 1047630 7 4 \n", "16 1048672 4 1 \n", "17 1049815 4 1 \n", "18 1050670 10 7 \n", "19 1050718 6 1 \n", "20 1054590 7 3 \n", "21 1054593 10 5 \n", "22 1056784 3 1 \n", "23 1057013 8 4 \n", "24 1059552 1 1 \n", "25 1065726 5 2 \n", "26 1066373 3 2 \n", "27 1066979 5 1 \n", "28 1067444 2 1 \n", "29 1070935 1 1 \n", ".. ... ... ... \n", "669 1350423 5 10 \n", "670 1352848 3 10 \n", "671 1353092 3 2 \n", "672 1354840 2 1 \n", "673 1354840 5 3 \n", "674 1355260 1 1 \n", "675 1365075 4 1 \n", "676 1365328 1 1 \n", "677 1368267 5 1 \n", "678 1368273 1 1 \n", "679 1368882 2 1 \n", "680 1369821 10 10 \n", "681 1371026 5 10 \n", "682 1371920 5 1 \n", "683 466906 1 1 \n", "684 466906 1 1 \n", "685 534555 1 1 \n", "686 536708 1 1 \n", "687 566346 3 1 \n", "688 603148 4 1 \n", "689 654546 1 1 \n", "690 654546 1 1 \n", "691 695091 5 10 \n", "692 714039 3 1 \n", "693 763235 3 1 \n", "694 776715 3 1 \n", "695 841769 2 1 \n", "696 888820 5 10 \n", "697 897471 4 8 \n", "698 897471 4 8 \n", "\n", " Uniformity of Cell Shape Marginal Adhesion Single Epithelial Cell Size \\\n", "0 1 1 2 \n", "1 4 5 7 \n", "2 1 1 2 \n", "3 8 1 3 \n", "4 1 3 2 \n", "5 10 8 7 \n", "6 1 1 2 \n", "7 2 1 2 \n", "8 1 1 2 \n", "9 1 1 2 \n", "10 1 1 1 \n", "11 1 1 2 \n", "12 3 3 2 \n", "13 1 1 2 \n", "14 5 10 7 \n", "15 6 4 6 \n", "16 1 1 2 \n", "17 1 1 2 \n", "18 7 6 4 \n", "19 1 1 2 \n", "20 2 10 5 \n", "21 5 3 6 \n", "22 1 1 2 \n", "23 5 1 2 \n", "24 1 1 2 \n", "25 3 4 2 \n", "26 1 1 1 \n", "27 1 1 2 \n", "28 1 1 2 \n", "29 3 1 2 \n", ".. ... ... ... 
\n", "669 10 8 5 \n", "670 7 8 5 \n", "671 1 2 2 \n", "672 1 1 2 \n", "673 2 1 3 \n", "674 1 1 2 \n", "675 4 1 2 \n", "676 2 1 2 \n", "677 1 1 2 \n", "678 1 1 2 \n", "679 1 1 2 \n", "680 10 10 5 \n", "681 10 10 4 \n", "682 1 1 2 \n", "683 1 1 2 \n", "684 1 1 2 \n", "685 1 1 2 \n", "686 1 1 2 \n", "687 1 1 2 \n", "688 1 1 2 \n", "689 1 1 2 \n", "690 1 3 2 \n", "691 10 5 4 \n", "692 1 1 2 \n", "693 1 1 2 \n", "694 1 1 3 \n", "695 1 1 2 \n", "696 10 3 7 \n", "697 6 4 3 \n", "698 8 5 4 \n", "\n", " Bare Nuclei Bland Chromatin Normal Nucleoli Mitoses Class \n", "0 1 3 1 1 2 \n", "1 10 3 2 1 2 \n", "2 2 3 1 1 2 \n", "3 4 3 7 1 2 \n", "4 1 3 1 1 2 \n", "5 10 9 7 1 4 \n", "6 10 3 1 1 2 \n", "7 1 3 1 1 2 \n", "8 1 1 1 5 2 \n", "9 1 2 1 1 2 \n", "10 1 3 1 1 2 \n", "11 1 2 1 1 2 \n", "12 3 4 4 1 4 \n", "13 3 3 1 1 2 \n", "14 9 5 5 4 4 \n", "15 1 4 3 1 4 \n", "16 1 2 1 1 2 \n", "17 1 3 1 1 2 \n", "18 10 4 1 2 4 \n", "19 1 3 1 1 2 \n", "20 10 5 4 4 4 \n", "21 7 7 10 1 4 \n", "22 1 2 1 1 2 \n", "23 0 7 3 1 4 \n", "24 1 3 1 1 2 \n", "25 7 3 6 1 4 \n", "26 1 2 1 1 2 \n", "27 1 2 1 1 2 \n", "28 1 2 1 1 2 \n", "29 1 1 1 1 2 \n", ".. ... ... ... ... ... 
\n", "669 5 7 10 1 4 \n", "670 8 7 4 1 4 \n", "671 1 3 1 1 2 \n", "672 1 3 1 1 2 \n", "673 1 1 1 1 2 \n", "674 1 2 1 1 2 \n", "675 1 1 1 1 2 \n", "676 1 2 1 1 2 \n", "677 1 1 1 1 2 \n", "678 1 1 1 1 2 \n", "679 1 1 1 1 2 \n", "680 10 10 10 7 4 \n", "681 10 5 6 3 4 \n", "682 1 3 2 1 2 \n", "683 1 1 1 1 2 \n", "684 1 1 1 1 2 \n", "685 1 1 1 1 2 \n", "686 1 1 1 1 2 \n", "687 1 2 3 1 2 \n", "688 1 1 1 1 2 \n", "689 1 1 1 8 2 \n", "690 1 1 1 1 2 \n", "691 5 4 4 1 4 \n", "692 1 1 1 1 2 \n", "693 1 2 1 2 2 \n", "694 2 1 1 1 2 \n", "695 1 1 1 1 2 \n", "696 3 8 10 2 4 \n", "697 4 10 6 1 4 \n", "698 5 10 4 1 4 \n", "\n", "[699 rows x 11 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "breast_cancer_df" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "features = ['Clump Thickness', 'Uniformity of Cell Size',\n", " 'Uniformity of Cell Shape', 'Marginal Adhesion', 'Single Epithelial Cell Size',\n", " 'Bare Nuclei', 'Bland Chromatin', 'Normal Nucleoli', 'Mitoses']" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "corr = []\n", "for f in features:\n", " c = breast_cancer_df[f].corr(breast_cancer_df['Class'], method='spearman')\n", " corr.append(c)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.68245186937823676,\n", " 0.85548668244535364,\n", " 0.83639412545877556,\n", " 0.7279952033877698,\n", " 0.76273086721512906,\n", " 0.81376763955180775,\n", " 0.74035036553976241,\n", " 0.74382258149235514,\n", " 0.52676617489092259]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "corr" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's take a look at the distribution of the dataset:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ 
"benign_samples = breast_cancer_df[breast_cancer_df['Class'] == 2]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "malignant_samples = breast_cancer_df[breast_cancer_df['Class'] == 4]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Percentage of benign examples: 66.0%\n" ] } ], "source": [ "print(\"Percentage of benign examples: {}%\".format(np.round(len(benign_samples) / len(breast_cancer_df) * 100)))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Percentage of malignant examples: 34.0%\n" ] } ], "source": [ "print(\"Percentage of malignant examples: {}%\".format(np.round(len(malignant_samples) / len(breast_cancer_df) * 100)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Model fitting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's use Scikit-learn to split the dataset in training set and testing set:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(breast_cancer_df.loc[:, 'Clump Thickness':'Mitoses'],\n", " breast_cancer_df['Class'] / 2 - 1, test_size=.3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that I scaled the `'Class'` label such that `0` represents benign sample and `1` represents malignant samples.\n", "This has to be done solely because of the assumptions of the logistic regression algorithm implemented in ``macaw``." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's import the ``LogisticRegression`` objective function from ``macaw``:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from macaw.objective_functions import LogisticRegression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See https://mirca.github.io/macaw/api/objective_functions.html#macaw.objective_functions.LogisticRegression\n", "for the documentation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's instantiate a ``LogisticRegression`` object, passing the labels ``y_train`` and the features ``X_train``:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "logreg = LogisticRegression(y=np.array(y_train, dtype=float), X=np.array(X_train, dtype=float))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's use the `fit` method to obtain the maximum likelihood weights.\n", "\n", "*Note that we need to pass an initial estimate for the linear weights and the bias of the `LogisticRegression`: one entry per feature plus one for the bias, hence `X_train.shape[1] + 1` entries*:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "res = logreg.fit(x0=np.zeros(X_train.shape[1] + 1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The maximum likelihood weights can be accessed through the ``.x`` attribute:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0.6716211 , -0.12269987, 0.22323592, 0.37896363,\n", " -0.06950043, 0.48099004, 0.65926442, 0.25699509,\n", " 0.58662442, -11.18542664])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res.x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Additionally, we can check the status of the `fit` and the number of iterations it took to converge."
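, "\n", "\n",
"To make the status and the iteration count concrete, here is a minimal sketch of an iterative maximum-likelihood fit. It stops once no parameter changes by more than a tolerance, which is the kind of criterion that ``res.status`` reports. This is plain gradient ascent and not necessarily the optimizer ``macaw`` uses internally. The function name ``fit_logistic``, the learning rate ``lr`` and the defaults are assumptions. The parameter layout (weights first, bias last) mirrors the ``x0`` passed to ``fit``.

```python
import numpy as np


def fit_logistic(X, y, x0, lr=0.01, tol=1e-6, max_iter=100000):
    # Hypothetical helper: gradient ascent on the mean Bernoulli log-likelihood.
    # theta = (w_1, ..., w_d, bias); y must contain 0/1 labels.
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    theta = np.asarray(x0, dtype=float).copy()
    for niters in range(1, max_iter + 1):
        z = X @ theta[:-1] + theta[-1]
        p = 0.5 * (1.0 + np.tanh(0.5 * z))  # numerically stable sigmoid(z)
        resid = y - p
        # Gradient with respect to the weights, then with respect to the bias.
        grad = np.append(X.T @ resid, resid.sum()) / len(y)
        new_theta = theta + lr * grad
        # Converged: no parameter moved by more than tol in this iteration.
        if np.max(np.abs(new_theta - theta)) < tol:
            return new_theta, niters
        theta = new_theta
    return theta, max_iter


# Tiny toy problem with overlapping classes, so a finite maximum exists.
X_toy = np.array([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]])
y_toy = np.array([0, 0, 1, 0, 1, 1])
theta, niters = fit_logistic(X_toy, y_toy, x0=np.zeros(2), lr=0.1)
print(theta, niters)
```

Plain gradient ascent needs a step size suited to the scale of the features. On the notebook's data (features from 0 to 10), a smaller ``lr`` and more iterations may therefore be needed than in this toy example."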
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Success: parameters have not changed by 1e-06 since the previous iteration.'" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res.status" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of iterations needed: 237\n" ] } ], "source": [ "print(\"Number of iterations needed: {}\".format(res.niters))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's compute the accuracy of our model on the test set. For that, we can use the ``predict`` method, passing the test samples. This method outputs the predicted class (`0` or `1`) of each sample:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0.,\n", " 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0.,\n", " 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0., 0.,\n", " 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0.,\n", " 1., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 0., 0.,\n", " 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.,\n", " 0., 1., 0., 1., 1., 1., 0., 0., 1., 1., 0., 0., 0.,\n", " 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0.,\n", " 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0.,\n", " 1., 1., 1., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0.,\n", " 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n", " 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1.,\n", " 0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0.,\n", " 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0.,\n", " 0., 0., 1., 1., 0., 1., 0., 0., 1., 1., 0., 0., 1.,\n", " 0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 0., 0., 0.,\n", " 0., 0.])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logreg.predict(np.array(X_test))" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "Now we can compute the percentage of samples correctly classified:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "accuracy = np.round((np.array(y_test) == logreg.predict(np.array(X_test))).sum() / len(np.array(y_test)) * 100, decimals=5)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The accuracy of the model is 96.19048%\n" ] } ], "source": [ "print('The accuracy of the model is {}%'.format(accuracy))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Comparison against scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's compare ``macaw`` against ``scikit-learn``:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": true }, "outputs": [], "source": [ "logit = LogisticRegression()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", " verbose=0, warm_start=False)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logit.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.96190476190476193" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logit.score(X_test, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Looks like** `macaw` **has a good agreement with** `sklearn` 
**:)!**" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## 4. Logistic Regression with L1 Regularization" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from macaw.objective_functions import L1LogisticRegression" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": true }, "outputs": [], "source": [ "alpha = [.1, 1., 10., 100.]" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": true }, "outputs": [], "source": [ "acc = []\n", "for a in alpha:\n", " l1logreg = L1LogisticRegression(y=np.array(y_train, dtype=float), X=np.array(X_train, dtype=float), alpha=a)\n", " res_l1 = l1logreg.fit(x0=np.zeros(X_train.shape[1] + 1) + 1e-1)\n", " accuracy = np.round((np.array(y_test) == l1logreg.predict(np.array(X_test))).sum() / len(np.array(y_test)) * 100,\n", " decimals=5)\n", " acc.append(accuracy)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[95.238100000000003,\n", " 95.714290000000005,\n", " 96.666669999999996,\n", " 62.380949999999999]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "acc" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZkAAAEOCAYAAABbxmo1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFHxJREFUeJzt3X+QXeV93/H3R6sfSIYRFIIyCFyBwTIC1xFeM04nmMTg\nRpSAYzfG/MjgAIahsTEpk1A8ZUpm6ikp00nHdggeGmHcScJPYwacMQTHHlMcYmsF1IYADkZgRJ1S\nQMg1RkY/vv1jr2C12ZXuSvvsvdr7fs3saO85zznne/XM3c99znPOvakqJElqYU6vC5AkzV6GjCSp\nGUNGktSMISNJasaQkSQ1Y8hIkpoxZCRJzRgykqRmDBlJUjOGjCSpmbm9LqDXDjrooFq2bFmvy5Ck\nvcratWtfrKpf2FW7gQ+ZZcuWMTIy0usyJGmvkuTZbtp5ukxSUxtefZ1zb/gOG159vdelqAcMGUlN\n3b52Pff/4EW+/ND6XpeiHjBkJDVTVax+YB0Aqx9Yh18tMngMGUnNfHfdy/xk02YANr62mTXPbOhx\nRZpphoykZlZ/ex2vbd4KwGubt7L6gad7XJFm2sBfXabZZcOrr3PpLQ/z2Y+u5IC3zO91OQPl419a\nw9cff2GHZfOGwvYzZFXwjSdeYNkVf7VDm5OPXsKffWx4psrUDHMko1nFSebeuXzVO1i6/z4smPvm\nn5XNW3ecgxn7eMHcOSzdfyGXr1o+YzVq5hkyu8nLMvuPk8y99fYl+3HfZSdy8oolLJw3tNO2C+cN\n8YEVS7jvsvfx9iX7zVCF6gVDZjf5jrn/OMnce4vmz+Xas4/jylOPZv7cif+8zJ87hytPPZo/Ofs4\nFs33jP1sZ8jsBt8x9ycnmfvHMUsXM39okpAZmsOxSxfPcEXqFd9G7IaJ3jEff/g/63FVg8VJ5v72\n/fWvsGXbNgACLJg3h59v3kYBW7Zt43vPb+Rdh+3f0xo1MxzJ7AbfMfeek8z97bvPvMymzdtYMHcO\nh+y/kM+euZJDOv21afM21qx7udclaoY4ktkF3zH3p+2TzH9w+/f4xuMvvBH6E1k4b4iTjj6Ya37r\nXzgHMEMe+dErDCV8YMWSN/7fTzjqIP7g9u9xz/f/kYefc75sUDiS2QXfMfcvJ5n715EH78vVH37n\nDv/v2/vr6g+/kyN/Yd8eV6iZ4qtuF3zH3P+2TzK/vmXbP1nnJHNvfPG84yddd8Z7DuOM9xw2g9Wo\nlxzJdMF3zP1t/CTzPvPmkM667ZPMknrDkJkCL8vsT04yS/3LkJkC3zH3p7GTzPdd9j5+/ZhffOPO\n86HESWaphwyZKfAdc39yklnqX04eTIGXZfYnJ5ml/mXITMGRB+/LJe8/aoc/WtvfMd+65jm+9uiP\ne1idJPWfDPrnbg0PD9fIyEivy5CkvUqStVW1yzvOnZORJDVjyEiSmjFkJEnNGDKSpGYMGUlSM4aM\nJKkZQ0aS1IwhI0lqxpCRJDVjyEiSmjFkJEnNGDKSpGYMGUlSM4aMJKmZWRkySY5IsjrJ7b2uRZIG\nWdOQSXJpkkeTPJbk9/ZgPzckeSHJoxOsW5XkySRPJbkCoKqerqoL9qR2SdKeaxYySY4FLgSOB94F\n/EaSI8e1OTjJfuOW7dCm40Zg1QTHGAKuBU4BVgBnJVkxLU9AkrTHWo5kjga+U1U/q6otwLeAD49r\ncyJwZ5IFAEkuBD4/fkdVdT/w8gTHOB54qjNyeR24GfjgND4HSdIeaBkyjwInJDkwySLgXwOHjW1Q\nVbcB9wK3JDkHOB/4yBSOsRR4bszj9cDSzjG/AKxM8umJNkxyWpLrN27cOIXDSZKmolnIVNXjwH8B\n/hq4B3gE2DpBu2uATcB1wOlV9dNpOPZLVXVxVb2tqq6epM3dVXXR4sWL9/RwkqRJNJ34r6rVVfXu\nqnofsAH4wfg2SU4AjgW+Alw1xUM8z46jo0M7yyRJfaD11WUHd
/59K6PzMX85bv1K4HpG51HOAw5M\n8pkpHGINcFSSw5PMB84E7pqO2iVJe671fTJfTvL3wN3AJ6rqlXHrFwFnVNUPq2obcC7w7PidJLkJ\neBBYnmR9kgsAOhcUfJLReZ3HgVur6rF2T0eSNBWpql7X0FPDw8M1MjLS6zIkaa+SZG1VDe+q3ay8\n41+S1B8MGUlSM4aMJKkZQ0aS1IwhI0lqxpCRJDVjyEiSmjFkJEnNGDKSpGYMGUlSM4aMJKkZQ0aS\n1IwhI0lqxpCRJDVjyEiSmjFkJEnNGDKSpGYMGUlSM4aMJKkZQ0aS1IwhI0lqxpCRJDVjyEiSmjFk\nJEnNGDKSpGYMGUlSM4aMJKkZQ0aS1IwhI0lqxpCRJDVjyEiSmjFkJEnNGDKSpGYMGUlSM4aMJKkZ\nQ0aS1IwhI0lqxpCRJDVjyEiSmjFkJEnNdBUySe5IcmoSQ0mS1LVuQ+NPgbOBf0jyR0mWN6xJkjRL\ndBUyVfX1qjoHOA54Bvh6kr9Ncl6SeS0LlCTtvbo+/ZXkQOB3gI8DDwOfZTR07mtS2R5IckSS1Ulu\n73UtkjTIup2T+QrwP4FFwGlVdXpV3VJVlwD77mS7f5fksSSPJrkpyT67U2SSG5K8kOTRCdatSvJk\nkqeSXAFQVU9X1QW7cyxJ0vTpdiTzuapaUVVXV9WPx66oquGJNkiyFPgUMFxVxwJDwJnj2hycZL9x\ny46cYHc3AqsmOMYQcC1wCrACOCvJii6fkySpsW5DZkWS/bc/SHJAkt/tYru5wMIkcxkdBf3vcetP\nBO5MsqCz3wuBz4/fSVXdD7w8wf6PB57qjFxeB24GPtjNE5IktddtyFxYVa9sf1BVG4ALd7ZBVT0P\n/FfgR8CPgY1V9dfj2twG3AvckuQc4HzgI92Xz1LguTGP1wNLkxyY5AvAyiSfnmjDJKcluX7jxo1T\nOJwkaSq6DZmhJNn+oHOaav7ONkhyAKOjisOBQ4C3JPnt8e2q6hpgE3AdcHpV/bTLmiZVVS9V1cVV\n9baqunqSNndX1UWLFy/e08NJkibRbcjcw+ho46QkJwE3dZbtzMnAuqr6v1W1GbgD+JfjGyU5ATgW\n+ApwVdeVj3oeOGzM40M7yyRJfaDbkPn3wDeBf9v5+Rvg8l1s8yPgvUkWdUZBJwGPj22QZCVwPaMj\nnvOAA5N8pvvyWQMcleTwJPMZvbDgrilsL0lqaG43japqG6Ons67rdsdV9Z3OfSoPAVsYvbfm+nHN\nFgFnVNUPAZKcy+i9ODtIchPwq8BBSdYDV1XV6qrakuSTjM7rDAE3VNVj3dYoSWorVbXrRslRwNWM\nXib8xr0uVXVEu9JmxvDwcI2MjPS6DEnaqyRZO9ktLGN1e7rsi4yOYrYAvwb8D+DPd788SdIg6DZk\nFlbV3zA68nm2qv4QOLVdWZKk2aCrORng552P+f+HzhzI8+zk42QkSYLuRzKXMjpJ/yng3cBvAx9r\nVZQkaXbY5Uimc+PlR6vq94GfMnqpsSRJu7TLkUxVbQV+ZQZqkSTNMt3OyTyc5C7gNuDV7Qur6o4m\nVUmSZoVuQ2Yf4CXg/WOWFaMfFSNJ0oS6vePfeRhJ0pR1FTJJvsjoyGUHVXX+tFckSZo1uj1d9tUx\nv+8DfIh/+gVkkiTtoNvTZV8e+7jzgZUPNKlIkjRrdHsz5nhHAQdPZyGSpNmn2zmZ/8eOczL/yOh3\nzEiSNKluT5ft17oQSdLs09XpsiQfSrJ4zOP9k/xmu7IkSbNBt3MyV1XVxu0PquoV4Ko2JUmSZotu\nQ2aidt1e/ixJGlDdhsxIkj9O8rbOzx8Da1sWJkna+3UbMpcArwO3ADcDm4BPtCpKkjQ7dHt12avA\nFY1rkSTNMt1eXXZfkv3HPD4gyb3typIkzQbdni47qHNFGQBVtQHv+Jck7UK3IbMtyVu3P0iyjAk+\nlVmSpLG6vQz5PwAPJPkWE
OAE4KJmVUmSZoVuJ/7vSTLMaLA8DNwJvNayMEnS3q/bD8j8OHApcCjw\nCPBe4EF2/DpmSZJ20O2czKXAe4Bnq+rXgJXAKzvfRJI06LoNmU1VtQkgyYKqegJY3q4sSdJs0O3E\n//rOfTJ3Avcl2QA8264sSdJs0O3E/4c6v/5hkm8Ci4F7mlUlSZoVpvxJylX1rRaFSJJmn27nZCRJ\nmjJDRpLUjCEjSWrGkJEkNWPISJKaMWQkSc0YMpKkZgwZSVIzhowkqRlDRpLUjCEjSWrGkJEkNWPI\nSJKaMWQkSc0YMpKkZgwZSVIzhowkqRlDRpLUzKwMmSRHJFmd5PZe1yJJg6xZyCRZnuSRMT8/SfJ7\nu7mvG5K8kOTRCdatSvJkkqeSXAFQVU9X1QV7+hwkSXumWchU1ZNV9UtV9UvAu4GfAV8Z2ybJwUn2\nG7fsyAl2dyOwavzCJEPAtcApwArgrCQrpucZSJL21EydLjsJ+GFVPTtu+YnAnUkWACS5EPj8+I2r\n6n7g5Qn2ezzwVGfk8jpwM/DBaa1ckrTbZipkzgRuGr+wqm4D7gVuSXIOcD7wkSnsdynw3JjH64Gl\nSQ5M8gVgZZJPT7RhktOSXL9x48YpHE6SNBXNQybJfOB04LaJ1lfVNcAm4Drg9Kr66Z4es6peqqqL\nq+ptVXX1JG3urqqLFi9evKeHkyRNYiZGMqcAD1XV/5loZZITgGMZna+5aor7fh44bMzjQzvLJEl9\nYCZC5iwmOFUGkGQlcD2j8yjnAQcm+cwU9r0GOCrJ4Z0R05nAXXtYryRpmjQNmSRvAT4A3DFJk0XA\nGVX1w6raBpwLjL84gCQ3AQ8Cy5OsT3IBQFVtAT7J6LzO48CtVfXY9D8TSdLuSFX1uoaeGh4erpGR\nkV6XIUl7lSRrq2p4V+1m5R3/kqT+YMhIkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJasaQ\nkSQ1Y8hIkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJasaQkSQ1Y8hIkpoxZCRJzRgykqRm\nDBlJUjOGjCSpGUNGktSMISNJasaQkSQ1Y8hIkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJ\nasaQkSQ1Y8hIkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJA2bDq69z7g3fYcOrrzc/liEj\nSQPm9rXruf8HL/Llh9Y3P5YhI0kDpKpY/cA6AFY/sI6qano8Q0aSBsh3173MTzZtBmDja5tZ88yG\npsczZCRpgKz+9jpe27wVgNc2b2X1A083Pd7cpnuXJPXMx7+0hq8//sIOy+YNhe1nyKrgG0+8wLIr\n/mqHNicfvYQ/+9jwtNTgSEaSZqnLV72Dpfvvw4K5b/6p37x1xzmYsY8XzJ3D0v0Xcvmq5dNWgyEj\nSbPU25fsx32XncjJK5awcN7QTtsunDfEB1Ys4b7L3sfbl+w3bTUYMpI0iy2aP5drzz6OK089mvlz\nJ/6TP3/uHK489Wj+5OzjWDR/emdRDBlJGgDHLF3M/KFJQmZoDscuXdzkuIaMJA2A769/hS3btgEQ\nYJ95c0hn3ZZt2/je8xubHNeQkaQB8N1nXmbT5m0smDuHQ/ZfyGfPXMkhnYsCNm3expp1Lzc5riEj\nSQPgkR+9wlDyxuT+rx/zi29cFDCU8PBzbW7K9D4ZSRoARx68L5e8/yjOeM9hbyzbflHArWue42uP\n/rjJcdP6c2v63fDwcI2MjPS6DEnaqyRZW1W7vGPT02WSpGYMGUlSMwM9J5PkNODFJM+OW7UYGH89\n30TLDgJebFTerkxUz0zso9ttdtVuZ+snW9fv/TIdfbK7++lmm0HsE/C1MtmyPe2Tf95Vq6oa2B/g\n+m6XT7JspN9qb72PbrfZVbudrd9b+2U6+qRlvwxin0xXv8zG18pM9cmgny67ewrLJ2vbK9N
Rz+7s\no9ttdtVuZ+v31n6Zrlpa9csg9gn4Wun2OE0M/NVleyLJSHVxdYVmlv3Sf+yT/jNTfTLoI5k9dX2v\nC9CE7Jf+Y5/0nxnpE0cykqRmHMlIkpoxZCRJzRgykqRmDJlGkhyRZHWS23tdyyBL8pYkX0ry35Oc\n0+t6NMrXR/9J8pud18ktSf7VdO3XkJlAkhuSvJDk0XHLVyV5MslTSa7Y2T6q6umquqBtpYNpiv3z\nYeD2qroQOH3Gix0gU+kXXx8zY4p9cmfndXIx8NHpqsGQmdiNwKqxC5IMAdcCpwArgLOSrEjyziRf\nHfdz8MyXPFBupMv+AQ4Fnus02zqDNQ6iG+m+XzQzbmTqfXJlZ/20GOjPLptMVd2fZNm4xccDT1XV\n0wBJbgY+WFVXA78xsxUOtqn0D7Ce0aB5BN9UNTXFfvn7ma1uME2lT5I8DvwR8LWqemi6avBF172l\nvPmOGEb/eC2drHGSA5N8AViZ5NOti9Ok/XMH8G+SXEf/fdzJIJiwX3x99NRkr5VLgJOB30py8XQd\nzJFMI1X1EqPnNtVDVfUqcF6v69COfH30n6r6HPC56d6vI5nuPQ8cNubxoZ1l6g/2T3+yX/rPjPaJ\nIdO9NcBRSQ5PMh84E7irxzXpTfZPf7Jf+s+M9okhM4EkNwEPAsuTrE9yQVVtAT4J3As8DtxaVY/1\nss5BZf/0J/ul//RDn/gBmZKkZhzJSJKaMWQkSc0YMpKkZgwZSVIzhowkqRlDRpLUjCEj9VCSZ5Ic\ntKdtpH5lyEiSmjFkpBmS5M4ka5M8luSiceuWJXkiyV8keTzJ7UkWjWlySZKHknw/yTs62xyf5MEk\nDyf52yTLZ/QJSV0wZKSZc35VvRsYBj6V5MBx65cDf1pVRwM/AX53zLoXq+o44Drg9zvLngBOqKqV\nwH8E/nPT6qXdYMhIM+dTSf4X8HeMfgruUePWP1dV3+78/ufAr4xZd0fn37XAss7vi4HbOl+t+9+A\nY1oULe0JQ0aaAUl+ldEvhPrlqnoX8DCwz7hm4z9IcOzjn3f+3cqb3wP1n4BvVtWxwGkT7E/qOUNG\nmhmLgQ1V9bPOnMp7J2jz1iS/3Pn9bOCBLva5/XtAfmdaqpSmmSEjzYx7gLljvkf97yZo8yTwiU6b\nAxidf9mZa4CrkzyM33KrPuVH/Ut9IMky4KudU1/SrOFIRpLUjCMZSVIzjmQkSc0YMpKkZgwZSVIz\nhowkqRlDRpLUjCEjSWrm/wMSZbcCeqb7+QAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.loglog(alpha, acc, '*', markersize=15)\n", "plt.ylabel('accuracy')\n", "plt.xlabel('alpha')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }