久久久久久久999_99精品久久精品一区二区爱城_成人欧美一区二区三区在线播放_国产精品日本一区二区不卡视频_国产午夜视频_欧美精品在线观看免费

 找回密碼
 立即注冊

QQ登錄

只需一步,快速開始

搜索
查看: 2348|回復: 0
打印 上一主題 下一主題
收起左側

Python隨機森林例子 源碼分享

[復制鏈接]
跳轉到指定樓層
樓主
  1.     "#測試gini\n",
  2.     "gini=calGini((l,r),classLabels)\n",
  3.     "print(gini)\n"
  4.    ]
  5.   },
  6.   {
  7.    "cell_type": "code",
  8.    "execution_count": 19,
  9.    "metadata": {},
  10.    "outputs": [],
  11.    "source": [
  12.     "def getBestSplit(dataSet,featureNumbers):\n",
  13.     "    '''\n",
  14.     "    對于一個數據集,選擇featureNumber個特征進行簡單劃分,得到最好的特征和劃分結果\n",
  15.     "    args:\n",
  16.     "      dataSet:數據集,類型:list\n",
  17.     "      featureNumbers:選擇的特征值數,類型:int\n",
  18.     "      classLabels:所有分類,類型:list\n",
  19.     "    ''' \n",
  20.     "    \n",
  21.     "    #樣本數\n",
  22.     "    m=len(dataSet)\n",
  23.     "    if m==0:\n",
  24.     "        return None\n",
  25.     "    #樣本特征值數+1(因為最后有一個標簽)\n",
  26.     "    totalColumnNumber=len(dataSet[0])\n",
  27.     "    #隨機選擇的特征索引\n",
  28.     "    randomSelectedFeatures=[]\n",
  29.     "    \n",
  30.     "    \n",
  31.     "    \n",
  32.     "    #選擇數目必須在特征數目范圍內\n",
  33.     "    if totalColumnNumber-1>=featureNumbers:        \n",
  34.     "        #借助這個變量防止選擇重復的特征進入\n",
  35.     "        indexList=list(range(totalColumnNumber-1))            \n",
  36.     "        for j in range(featureNumbers):\n",
  37.     "            #索引序列長度\n",
  38.     "            leftSize=len(indexList)\n",
  39.     "            #隨機數\n",
  40.     "            randIndex=random.randrange(leftSize)\n",
  41.     "            #索引學列隨機數處數據彈出,放入選擇特征列表\n",
  42.     "            origIndex=indexList.pop(randIndex)\n",
  43.     "            #存入的是原始數據特征索引\n",
  44.     "            randomSelectedFeatures.append(origIndex)\n",
  45.     "    else:\n",
  46.     "        randomSelectedFeatures=range(totalColumnNumber-1)#特征全部被選擇\n",
  47.     "    \n",
  48.     "    \n",
  49.     "   # print(\"current select features\")\n",
  50.     "   # print(randomSelectedFeatures)\n",
  51.     "\n",
  52.     "    #當前數據集的標簽序列\n",
  53.     "    class_values=list(set(item[-1] for item in dataSet))\n",
  54.     "    \n",
  55.     "    #對于每個特征以及每個特征值進行簡單劃分\n",
  56.     "    #保留最小的基尼系數\n",
  57.     "    minGini=9999\n",
  58.     "    #存入最好的信息\n",
  59.     "    bestInfor={}\n",
  60.     "    #外層循環,對于每個特征\n",
  61.     "    for index in randomSelectedFeatures:\n",
  62.     "        #內層循環對于每個特征值\n",
  63.     "        tempFeatureValueList=list(set(item[index] for item in dataSet))\n",
  64.     "        #print(len(tempFeatureValueList))\n",
  65.     "        for tempValue in tempFeatureValueList:\n",
  66.     "            #簡單分類\n",
  67.     "            groups=simpleSplit(dataSet,index,tempValue)            \n",
  68.     "            #print(\"currentIndex:%d,CurrentTempValue:%f\"%(index,tempValue))\n",
  69.     "            #計算基尼系數\n",
  70.     "            gini=calGini(groups,class_values)\n",
  71.     "            #print(\"computed gini:\",gini)            \n",
  72.     "            if gini<minGini:\n",
  73.     "                minGini=gini\n",
  74.     "                #保存目前最后的信息\n",
  75.     "                bestInfor[\"index\"]=index#存入原來索引                \n",
  76.     "                bestInfor[\"indexValue\"]=tempValue\n",
  77.     "                bestInfor[\"groups\"]=groups\n",
  78.     "                bestInfor[\"gini\"]=gini\n",
  79.     "                \n",
  80.     "    return bestInfor"
  81.    ]
  82.   },
  83.   {
  84.    "cell_type": "code",
  85.    "execution_count": 20,
  86.    "metadata": {},
  87.    "outputs": [
  88.     {
  89.      "name": "stdout",
  90.      "output_type": "stream",
  91.      "text": [
  92.       "52 0.017\n"
  93.      ]
  94.     }
  95.    ],
  96.    "source": [
  97.     "#測試最好分類函數\n",
  98.     "bestInfor=getBestSplit(dataSet,3)\n",
  99.     "print(bestInfor[\"index\"],bestInfor[\"indexValue\"])"
  100.    ]
  101.   },
  102.   {
  103.    "cell_type": "code",
  104.    "execution_count": 21,
  105.    "metadata": {},
  106.    "outputs": [],
  107.    "source": [
  108.     "def terminalLabel(subSet):\n",
  109.     "    '''\n",
  110.     "    樹葉點對應的標簽\n",
  111.     "    args:\n",
  112.     "      subSet:當前數據集,最后列是標簽列,類型:list\n",
  113.     "    returns:\n",
  114.     "      當前列中最多的標簽,類型:原標簽類型\n",
  115.     "    '''\n",
  116.     "    #得到最后一列\n",
  117.     "    labelList=[item[-1] for item in subSet]\n",
  118.     "    #max函數,key后是函數,代表對前面的進行那種運算,這里是技術\n",
  119.     "    #max返回值是第一個參數,這里set是把labelList轉換成集合,即去掉重復項\n",
  120.     "    #key:相當于循環調用labelList.count(set(labelList))中的每個元素,然后max取得最大值\n",
  121.     "    #返回set(labelList)中對應最大的那個標簽\n",
  122.     "    return max(set(labelList), key=labelList.count)   # 輸出 subSet 中出現次數較多的標簽 \n",
  123.     "\n",
  124.     "    #下面的寫法也是成立的,利用lambda表達式,表達式中x從全面取,這種寫法可能更好理解些\n",
  125.     "    #return max(set(labelList), key=lambda x:labelList.count(x)) "
  126.    ]
  127.   },
  128.   {
  129.    "cell_type": "code",
  130.    "execution_count": 22,
  131.    "metadata": {},
  132.    "outputs": [
  133.     {
  134.      "name": "stdout",
  135.      "output_type": "stream",
  136.      "text": [
  137.       "R\n"
  138.      ]
  139.     }
  140.    ],
  141.    "source": [
  142.     "#測試\n",
  143.     "label=terminalLabel(l)\n",
  144.     "print(label)"
  145.    ]
  146.   },
  147.   {
  148.    "cell_type": "code",
  149.    "execution_count": 23,
  150.    "metadata": {},
  151.    "outputs": [],
  152.    "source": [
  153.     "#對得到的最好分類信息進行分割\n",
  154.     "def split(node, max_depth, min_size, n_features, depth):  # 創建子分割器 遞歸分類 直到分類結束\n",
  155.     "    '''\n",
  156.     "    :param node:        節點,類型:字典\n",
  157.     "                    bestInfor[\"index\"]=index#存入原來索引                \n",
  158.     "                    bestInfor[\"indexValue\"]=tempValue\n",
  159.     "                    bestInfor[\"groups\"]=groups\n",
  160.     "                    bestInfor[\"gini\"]=gini\n",
  161.     "    :param max_depth:   最大深度,int\n",
  162.     "    :param min_size:    最小,int\n",
  163.     "    :param n_features:  特征選取個數,int\n",
  164.     "    :param depth:       深度,int\n",
  165.     "    :return:\n",
  166.     "    '''\n",
  167.     "    left, right = node['groups']\n",
  168.     "    del (node['groups'])\n",
  169.     "\n",
  170.     "    if not left or not right:  # 如果只有一個子集\n",
  171.     "        node['left'] = node['right'] = terminalLabel(left + right)  # 投票出類型\n",
  172.     "        return\n",
  173.     "\n",
  174.     "    if depth >= max_depth:  # 如果即將超過\n",
  175.     "        node['left'], node['right'] = terminalLabel(left), terminalLabel(right)  # 投票出類型\n",
  176.     "        return\n",
  177.     "\n",
  178.     "    if len(left) <= min_size:  # 處理左子集\n",
  179.     "        node['left'] = terminalLabel(left)\n",
  180.     "    else:\n",
  181.     "        node['left'] = getBestSplit(left, n_features)  # node['left']是一個字典,形式為{'index':b_index, 'value':b_value, 'groups':b_groups},所以node是一個多層字典\n",
  182.     "        split(node['left'], max_depth, min_size, n_features, depth + 1)  # 遞歸,depth+1計算遞歸層數\n",
  183.     "\n",
  184.     "    if len(right) <= min_size:  # 處理右子集\n",
  185.     "        node['right'] = terminalLabel(right)\n",
  186.     "    else:\n",
  187.     "        node['right'] = getBestSplit(right, n_features)\n",
  188.     "        split(node['right'], max_depth, min_size, n_features, depth + 1)\n",
  189.     "        "
  190.    ]
  191.   },
  192.   {
  193.    "cell_type": "code",
  194.    "execution_count": 24,
  195.    "metadata": {},
  196.    "outputs": [],
  197.    "source": [
  198.     "#構建一個決策樹\n",
  199.     "def buildTree(train, max_depth, min_size, n_features):\n",
  200.     "    '''\n",
  201.     "    創建一個決策樹\n",
  202.     "    :param train:       訓練數據集\n",
  203.     "    :param max_depth:   決策樹深度不能太深 不然容易導致過擬合\n",
  204.     "    :param min_size:    葉子節點的大小\n",
  205.     "    :param n_features:  選擇的特征的個數\n",
  206.     "    :return\n",
  207.     "        root    返回決策樹\n",
  208.     "    '''\n",
  209.     "    root = getBestSplit(train, n_features)  # 獲取樣本數據集\n",
  210.     "    split(root, max_depth, min_size, n_features, 1)  # 進行樣本分割,構架決策樹\n",
  211.     "    return root  # 返回決策樹\n"
  212.    ]
  213.   },
  214.   {
  215.    "cell_type": "code",
  216.    "execution_count": 25,
  217.    "metadata": {},
  218.    "outputs": [
  219.     {
  220.      "name": "stdout",
  221.      "output_type": "stream",
  222.      "text": [
  223.       "{'index': 55, 'indexValue': 0.0114, 'gini': 0.0, 'left': {'index': 35, 'indexValue': 0.2288, 'gini': 0.0, 'left': 'R', 'right': {'index': 33, 'indexValue': 0.2907, 'gini': 0.0, 'left': 'R', 'right': {'index': 58, 'indexValue': 0.0057, 'gini': 0.0, 'left': {'index': 12, 'indexValue': 0.0493, 'gini': 0.0, 'left': 'R', 'right': 'R'}, 'right': 'R'}}}, 'right': {'index': 54, 'indexValue': 0.0063, 'gini': 0.0, 'left': {'index': 21, 'indexValue': 0.8384, 'gini': 0.0, 'left': 'M', 'right': 'M'}, 'right': {'index': 32, 'indexValue': 0.558, 'gini': 0.0, 'left': 'M', 'right': {'index': 58, 'indexValue': 0.0332, 'gini': 0.0, 'left': 'M', 'right': 'M'}}}}\n"
  224.      ]
  225.     }
  226.    ],
  227.    "source": [
  228.     "#測試決策樹\n",
  229.     "#選擇一個子集\n",
  230.     "s=putBackSample(dataSet,10)\n",
  231.     "tempTree=buildTree(s,10,1,3)\n",
  232.     "print(tempTree)"
  233.    ]
  234.   },
  235.   {
  236.    "cell_type": "code",
  237.    "execution_count": 26,
  238.    "metadata": {},
  239.    "outputs": [],
  240.    "source": [
  241.     "#根據決策樹進行預測\n",
  242.     "def predict(node, row):   # 預測模型分類結果\n",
  243.     "    '''\n",
  244.     "    在當前節點進行預測,row是待預測樣本\n",
  245.     "    args:\n",
  246.     "       node:樹節點\n",
  247.     "       row:待分類樣本\n",
  248.     "    return:\n",
  249.     "       分類標簽\n",
  250.     "    '''\n",
  251.     "    if row[node['index']] < node['indexValue']:\n",
  252.     "        if isinstance(node['left'], dict):       # isinstance 是 Python 中的一個內建函數。是用來判斷一個對象是否是一個已知的類型。\n",
  253.     "            return predict(node['left'], row)\n",
  254.     "        else:\n",
  255.     "            return node['left']\n",
  256.     "    else:\n",
  257.     "        if isinstance(node['right'], dict):\n",
  258.     "            return predict(node['right'], row)\n",
  259.     "        else:\n",
  260.     "            return node['right']"
  261.    ]
  262.   },
  263.   {
  264.    "cell_type": "code",
  265.    "execution_count": 27,
  266.    "metadata": {},
  267.    "outputs": [
  268.     {
  269.      "name": "stdout",
  270.      "output_type": "stream",
  271.      "text": [
  272.       "R R\n"
  273.      ]
  274.     }
  275.    ],
  276.    "source": [
  277.     "#測試下\n",
  278.     "label=predict(tempTree,s[0])\n",
  279.     "print(label,s[0][-1])"
  280.    ]
  281.   },
  282.   {
  283.    "cell_type": "code",
  284.    "execution_count": 28,
  285.    "metadata": {},
  286.    "outputs": [],
  287.    "source": [
  288.     "#多個樹的決策,多數服從少數\n",
  289.     "def baggingPredict(trees, row):\n",
  290.     "    \"\"\"\n",
  291.     "    多個樹的決策,多數服從少數\n",
  292.     "    Args:\n",
  293.     "        trees           決策樹的集合\n",
  294.     "        row             測試數據集的每一行數據\n",
  295.     "    Returns:\n",
  296.     "        返回隨機森林中,決策樹結果出現次數做大的\n",
  297.     "    \"\"\"\n",
  298.     "\n",
  299.     "    # 使用多個決策樹trees對測試集test的第row行進行預測,再使用簡單投票法判斷出該行所屬分類\n",
  300.     "    predictions = [predict(tree, row) for tree in trees]\n",
  301.     "    return max(set(predictions), key=predictions.count)\n"
  302.    ]
  303.   },
  304.   {
  305.    "cell_type": "code",
  306.    "execution_count": 29,
  307.    "metadata": {},
  308.    "outputs": [],
  309.    "source": [
  310.     "def subSample(dataSet, ratio):  \n",
  311.     "    '''\n",
  312.     "    按比例隨機抽取數據,有重復抽樣\n",
  313.     "    args:\n",
  314.     "      dataSet:數據集,類型:list\n",
  315.     "      ratio:0-1之間的數\n",
  316.     "    '''\n",
  317.     "    if ratio<0.0:\n",
  318.     "        return None\n",
  319.     "    if ratio>=1:\n",
  320.     "        return dataSet\n",
  321.     "    sampleNumber=int(len(dataSet)*ratio)\n",
  322.     "    subSet=putBackSample(dataSet,sampleNumber)\n",
  323.     "    return subSet"
  324.    ]
  325.   },
  326.   {
  327.    "cell_type": "code",
  328.    "execution_count": 30,
  329.    "metadata": {},
  330.    "outputs": [
  331.     {
  332.      "name": "stdout",
  333.      "output_type": "stream",
  334.      "text": [
  335.       "41\n"
  336.      ]
  337.     }
  338.    ],
  339.    "source": [
  340.     "#測試\n",
  341.     "subSet=subSample(dataSet,0.2)\n",
  342.     "print(len(subSet))"
  343.    ]
  344.   },
  345.   {
  346.    "cell_type": "code",
  347.    "execution_count": 31,
  348.    "metadata": {},
  349.    "outputs": [],
  350.    "source": [
  351.     "#隨機森林主函數\n",
  352.     "def buildRandomForest(train, max_depth=10, min_size=1, sample_size=0.2, n_trees=10, n_features=3):\n",
  353.     "    \"\"\"\n",
  354.     "    random_forest(評估算法性能,返回模型得分)\n",
  355.     "    Args:\n",
  356.     "        train           訓練數據集,類型:list        \n",
  357.     "        max_depth       決策樹深度不能太深,不然容易導致過擬合\n",
  358.     "        min_size        葉子節點的大小\n",
  359.     "        sample_size     訓練數據集的樣本比例,0,1之間的數\n",
  360.     "        n_trees         決策樹的個數\n",
  361.     "        n_features      選取的特征的個數\n",
  362.     "    Returns:\n",
  363.     "        trees:樹序列\n",
  364.     "    \"\"\"\n",
  365.     "\n",
  366.     "    trees = list()\n",
  367.     "    # n_trees 表示決策樹的數量\n",
  368.     "    for i in range(n_trees):\n",
  369.     "        # 隨機抽樣的訓練樣本, 隨機采樣保證了每棵決策樹訓練集的差異性\n",
  370.     "        sample = subSample(train, sample_size)\n",
  371.     "        # 創建一個決策樹\n",
  372.     "        tree = buildTree(sample, max_depth, min_size, n_features)\n",
  373.     "        trees.append(tree)\n",
  374.     "    return trees\n",
  375.     "  \n"
  376.    ]
  377.   },
  378.   {
  379.    "cell_type": "code",
  380.    "execution_count": 32,
  381.    "metadata": {},
  382.    "outputs": [],
  383.    "source": [
  384.     "def predictByForest(trees,test):\n",
  385.     "    '''\n",
  386.     "    predictions     每一行的預測結果,bagging 預測最后的分類結果\n",
  387.     "    '''\n",
  388.     "    # 每一行的預測結果,bagging 預測最后的分類結果\n",
  389.     "    predictions = [baggingPredict(trees, row) for row in test]\n",
  390.     "    return predictions"
  391.    ]
  392.   },
  393.   {
  394.    "cell_type": "code",
  395.    "execution_count": 33,
  396.    "metadata": {},
  397.    "outputs": [],
  398.    "source": [
  399.     "def calQuota(predictions,labelClass,OrigClassLabels):\n",
  400.     "    '''\n",
  401.     "    計算分類指標\n",
  402.     "    args:\n",
  403.     "      predictions:預測值,類型:list\n",
  404.     "      labelClass:真實標簽,類型:list\n",
  405.     "      OrigClassLabels:數據可能的標簽庫,一個正例一個負例標簽\n",
  406.     "    '''\n",
  407.     "    \n",
  408.     "    Pos=OrigClassLabels[0]\n",
  409.     "    Nev=OrigClassLabels[1]    \n",
  410.     "    #真正例   \n",
  411.     "    #TP=len([item for item in labelClass if item==Pos and predictions[labelClass.index(item)]==Pos])\n",
  412.     "    TP=0\n",
  413.     "    TN=0\n",
  414.     "    FP=0\n",
  415.     "    FN=0\n",
  416.     "    for j in range(len(predictions)):        \n",
  417.     "        if predictions[j]==Pos and  labelClass[j]==Pos:\n",
  418.     "            TP+=1\n",
  419.     "        if predictions[j]==Nev and  labelClass[j]==Nev:\n",
  420.     "            TN+=1\n",
  421.     "        if predictions[j]==Pos and  labelClass[j]==Nev:\n",
  422.     "            FP+=1\n",
  423.     "        if predictions[j]==Nev and  labelClass[j]==Pos:\n",
  424.     "            FN+=1\n",
  425.     "#     #真負例,下面的做法不行,原因是index可能得到不同的索引\n",
  426.     "#     TN=len([item for item in labelClass if item==Nev and predictions[labelClass.index(item)]==Nev])\n",
  427.     "#     #偽正例\n",
  428.     "#     FP=len([item for item in labelClass if item==Nev and predictions[labelClass.index(item)]==Pos])\n",
  429.     "#     #偽負例\n",
  430.     "#     FN=len([item for item in labelClass if item==Pos and predictions[labelClass.index(item)]==Nev])\n",
  431.     "\n",
  432.     "    #Recall,TruePosProp=TP/(TP+FN)#識別的正例占整個正例的比率\n",
  433.     "    #FalsPosProp=FP/(FP+TN)#識別的正例占整個負例的比率\n",
  434.     "    #Precition=TP/(TP+FP)#識別的正確正例占識別出所有正例的比率\n",
  435.     "    \n",
  436.     "    return TP,TN,FP,FN"
  437.    ]
  438.   },
  439.   {
  440.    "cell_type": "code",
  441.    "execution_count": 34,
  442.    "metadata": {},
  443.    "outputs": [],
  444.    "source": [
  445.     "#測試下:\n",
  446.     "trees=buildRandomForest(dataSet)\n",
  447.     "testSet=nonPutBackSample(dataSet,100)\n",
  448.     "prediction=predictByForest(trees,testSet)\n"
  449.    ]
  450.   },
  451.   {
  452.    "cell_type": "code",
  453.    "execution_count": 35,
  454.    "metadata": {},
  455.    "outputs": [
  456.     {
  457.      "name": "stdout",
  458.      "output_type": "stream",
  459.      "text": [
  460.       "(44, 39, 12, 5)\n"
  461.      ]
  462.     }
  463.    ],
  464.    "source": [
  465.     "labelClass=[item[-1] for item in testSet]\n",
  466.     "\n",
  467.     "tp=calQuota(prediction,labelClass,list(classLabels))\n",
  468.     "print(tp)"
  469.    ]
  470.   },
  471.   {
  472.    "cell_type": "code",
  473.    "execution_count": 36,
  474.    "metadata": {},
  475.    "outputs": [],
  476.    "source": [
  477.     "def accuracy( predicted,actual):  \n",
  478.     "    correct = 0\n",
  479.     "    for i in range(len(actual)):\n",
  480.     "        if actual[i] == predicted[i]:\n",
  481.     "            correct += 1\n",
  482.     "    return correct / float(len(actual)) * 100.0\n"
  483.    ]
  484.   },
  485.   {
  486.    "cell_type": "code",
  487.    "execution_count": 37,
  488.    "metadata": {},
  489.    "outputs": [
  490.     {
  491.      "name": "stdout",
  492.      "output_type": "stream",
  493.      "text": [
  494.       "83.0\n"
  495.      ]
  496.     }
  497.    ],
  498.    "source": [
  499.     "a=accuracy(prediction,labelClass)\n",
  500.     "print(a)"
  501.    ]
  502.   },
  503.   {
  504.    "cell_type": "code",
  505.    "execution_count": 38,
  506.    "metadata": {},
  507.    "outputs": [],
  508.    "source": [
  509.     "def createCrossValideSets(trainSet,n_folds,bPutBack=True):\n",
  510.     "    '''\n",
  511.     "    產生交叉驗證數據集\n",
  512.     "    Args:\n",
  513.     "        dataset     原始數據集       \n",
  514.     "        n_folds     數據的份數,數據集交叉驗證的份數,采用無放回抽取\n",
  515.     "        bPutBack    是否放回\n",
  516.     "    '''\n",
  517.     "    subSetsList=[]\n",
  518.     "    subLen=int(len(trainSet)/n_folds)\n",
  519.     "    if bPutBack:\n",
  520.     "        for j in range(n_folds):\n",
  521.     "            subSet=putBackSample(trainSet,subLen)\n",
  522.     "            subSetsList.append(subSet)\n",
  523.     "    else:\n",
  524.     "        for j in range(n_folds):\n",
  525.     "            subSet=nonPutBackSample(trainSet,subLen)\n",
  526.     "            subSetsList.append(subSet)\n",
  527.     "    return subSetsList"
  528.    ]
  529.   },
  530.   {
  531.    "cell_type": "code",
  532.    "execution_count": 39,
  533.    "metadata": {},
  534.    "outputs": [],
  535.    "source": [
  536.     "def randomForest(trainSet,testSet,max_depth=10, min_size=1, sample_size=0.2, n_trees=10, n_features=3):\n",
  537.     "    '''\n",
  538.     "    構造隨機森林并測試\n",
  539.     "     Args:\n",
  540.     "        train           訓練數據集,類型:list        \n",
  541.     "        testSet         測試集,類型:list\n",
  542.     "        max_depth       決策樹深度不能太深,不然容易導致過擬合\n",
  543.     "        min_size        葉子節點的大小\n",
  544.     "        sample_size     訓練數據集的樣本比例,0,1之間的數\n",
  545.     "        n_trees         決策樹的個數\n",
  546.     "        n_features      選取的特征的個數\n",
  547.     "    Returns:\n",
  548.     "        predition       測試集預測值,類型:list\n",
  549.     "    '''\n",
  550.     "    trees=buildRandomForest(trainSet,max_depth, min_size, sample_size, n_trees, n_features)\n",
  551.     "    predition=predictByForest(trees,testSet)\n",
  552.     "    return predition"
  553.    ]
  554.   },
  555.   {
  556.    "cell_type": "code",
  557.    "execution_count": 40,
  558.    "metadata": {},
  559.    "outputs": [],
  560.    "source": [
  561.     "def evaluteAlgorithm(trainSet,algorithm,n_folds,*args):\n",
  562.     "    '''\n",
  563.     "    評價算法函數\n",
  564.     "     Args:\n",
  565.     "        dataset     原始數據集\n",
  566.     "        algorithm   使用的算法\n",
  567.     "        n_folds     數據的份數,數據集交叉驗證的份數,采用無放回抽取\n",
  568.     "        *args       其他的參數\n",
  569.     "    Returns:\n",
  570.     "        scores      模型得分\n",
  571.     "    '''\n",
  572.     "    folds = createCrossValideSets(trainSet, n_folds)\n",
  573.     "    scores = list()\n",
  574.     "    # 每次循環從 folds 從取出一個 fold 作為測試集,其余作為訓練集,遍歷整個 folds ,實現交叉驗證\n",
  575.     "    for fold in folds:\n",
  576.     "        train_set = list(folds)\n",
  577.     "        train_set.remove(fold)\n",
  578.     "        # 將多個 fold 列表組合成一個 train_set 列表, 類似 union all\n",
  579.     "        \"\"\"\n",
  580.     "        In [20]: l1=[[1, 2, 'a'], [11, 22, 'b']]\n",
  581.     "        In [21]: l2=[[3, 4, 'c'], [33, 44, 'd']]\n",
  582.     "        In [22]: l=[]\n",
  583.     "        In [23]: l.append(l1)\n",
  584.     "        In [24]: l.append(l2)\n",
  585.     "        In [25]: l\n",
  586.     "        Out[25]: [[[1, 2, 'a'], [11, 22, 'b']], [[3, 4, 'c'], [33, 44, 'd']]]\n",
  587.     "        In [26]: sum(l, [])\n",
  588.     "        Out[26]: [[1, 2, 'a'], [11, 22, 'b'], [3, 4, 'c'], [33, 44, 'd']]\n",
  589.     "        \"\"\"\n",
  590.     "        train_set = sum(train_set, [])\n",
  591.     "        test_set = list()\n",
  592.     "        # fold 表示從原始數據集 dataset 提取出來的測試集\n",
  593.     "#         for row in fold:\n",
  594.     "#             row_copy = list(row)\n",
  595.     "#             row_copy[-1] = None\n",
  596.     "#             test_set.append(row_copy)\n",
  597.     "        predicted = algorithm(train_set, fold, *args)\n",
  598.     "    \n",
  599.     "        actual = [row[-1] for row in fold]\n",
  600.     "\n",
  601.     "        # 計算隨機森林的預測結果的正確率\n",
  602.     "        accuracyValue = accuracy(predicted,actual)\n",
  603.     "        scores.append(accuracyValue)\n",
  604.     "    return scores"
  605.    ]
  606.   },
  607.   {
  608.    "cell_type": "code",
  609.    "execution_count": 41,
  610.    "metadata": {},
  611.    "outputs": [
  612.     {
  613.      "name": "stdout",
  614.      "output_type": "stream",
  615.      "text": [
  616.       "隨機因子= 0.13436424411240122\n",
  617.       "決策樹個數: 1\n",
  618.       "模型得分: [87.8048780487805, 90.2439024390244, 92.6829268292683, 85.36585365853658, 95.1219512195122]\n",
  619.       "平均準確度: 90.244%\n",
  620.       "隨機因子= 0.13436424411240122\n",
  621.       "決策樹個數: 10\n",
  622.       "模型得分: [92.6829268292683, 92.6829268292683, 87.8048780487805, 78.04878048780488, 100.0]\n",
  623.       "平均準確度: 90.244%\n"
  624.      ]
  625.     }
  626.    ],
  627.    "source": [
  628.     "    \n",
  629.     "    #綜合測試函數\n",
  630.     "    n_folds = 5        # 分成5份數據,進行交叉驗證\n",
  631.     "    max_depth = 20     # 調參(自己修改) #決策樹深度不能太深,不然容易導致過擬合\n",
  632.     "    min_size = 1       # 決策樹的葉子節點最少的元素數量\n",
  633.     "    sample_size = 1.0  # 做決策樹時候的樣本的比例\n",
  634.     "    # n_features = int((len(dataset[0])-1))\n",
  635.     "    n_features = 15     # 調參(自己修改) #準確性與多樣性之間的權衡\n",
  636.     "    for n_trees in [1, 10]:  # 理論上樹是越多越好\n",
  637.     "        scores = evaluteAlgorithm(dataSet, randomForest, n_folds, max_depth, min_size, sample_size, n_trees, n_features)\n",
  638.     "        # 每一次執行本文件時都能產生同一個隨機數\n",
  639.     "        random.seed(1)\n",
  640.     "        print('隨機因子=', random.random())  # 每一次執行本文件時都能產生同一個隨機數\n",
  641.     "        print('決策樹個數: %d' % n_trees)  # 輸出決策樹個數\n",
  642.     "        print('模型得分: %s' % scores)  # 輸出五份隨機樣本的模型得分\n",
  643.     "        print('平均準確度: %.3f%%' % (sum(scores)/float(len(scores))))  # 輸出五份隨機樣本的平均準確度\n"
  644.    ]
  645.   },
  646.   {
  647.    "cell_type": "code",
  648.    "execution_count": 42,
  649.    "metadata": {},
  650.    "outputs": [
  651.     {
  652.      "name": "stdout",
  653.      "output_type": "stream",
  654.      "text": [
  655.       "隨機因子= 0.13436424411240122\n",
  656.       "決策樹個數: 1\n",
  657.       "模型得分: [80.48780487804879, 75.60975609756098, 73.17073170731707, 75.60975609756098, 78.04878048780488]\n",
  658.       "平均準確度: 76.585%\n",
  659.       "隨機因子= 0.13436424411240122\n",
  660.       "決策樹個數: 10\n",
  661.       "模型得分: [87.8048780487805, 85.36585365853658, 90.2439024390244, 78.04878048780488, 92.6829268292683]\n",
  662.       "平均準確度: 86.829%\n"
  663.      ]
  664.     }
  665.    ],
  666.    "source": [
  667.     "    sample_size =0.5  # 做決策樹時候的樣本的比例\n",
  668.     "    \n",
  669.     "    for n_trees in [1, 10]:  # 理論上樹是越多越好\n",
  670.     "        scores = evaluteAlgorithm(dataSet, randomForest, n_folds, max_depth, min_size, sample_size, n_trees, n_features)\n",
  671.     "        # 每一次執行本文件時都能產生同一個隨機數\n",
  672.     "        random.seed(1)\n",
  673.     "        print('隨機因子=', random.random())  # 每一次執行本文件時都能產生同一個隨機數\n",
  674.     "        print('決策樹個數: %d' % n_trees)  # 輸出決策樹個數\n",
  675.     "        print('模型得分: %s' % scores)  # 輸出五份隨機樣本的模型得分\n",
  676.     "        print('平均準確度: %.3f%%' % (sum(scores)/float(len(scores))))  # 輸出五份隨機樣本的平均準確度"
  677.    ]
  678.   }
  679. ],
  680. 余下見附件
復制代碼

全部資料51hei下載地址:
隨機森林例子.zip (99.15 KB, 下載次數: 11)

評分

參與人數 1黑幣 +50 收起 理由
admin + 50 共享資料的黑幣獎勵!

查看全部評分

分享到:  QQ好友和群QQ好友和群 QQ空間QQ空間 騰訊微博騰訊微博 騰訊朋友騰訊朋友
收藏收藏 分享淘帖 頂 踩
回復

使用道具 舉報

您需要登錄后才可以回帖 登錄 | 立即注冊

本版積分規則

手機版|小黑屋|51黑電子論壇 |51黑電子論壇6群 QQ 管理員QQ:125739409;技術交流QQ群281945664

Powered by 單片機教程網

快速回復 返回頂部 返回列表
主站蜘蛛池模板: 国产一区高清 | 春色av| 精品国产区 | 亚洲毛片在线 | 男女激情网站免费 | 久久国产精品久久 | 99久久久久 | 国产精品久久久久久网站 | 精精国产xxxx视频在线 | 国产精品99久久免费观看 | 亚洲精品视频免费观看 | 中文字幕在线精品 | www.伊人.com | 午夜专区| 欧美二区在线 | 久久久久国产精品人 | 国产精品久久久亚洲 | 伊人电影院av | 国产美女在线免费观看 | 国产精品日韩一区二区 | 久久久久99 | 国产在线中文字幕 | 国产福利在线免费观看 | 成人午夜免费福利视频 | 日本天天操 | 精品国产免费人成在线观看 | 欧美精品第一页 | 色www精品视频在线观看 | 精品久久电影 | 国产日韩一区二区三区 | 99久久99热这里只有精品 | 免费中文字幕 | jizz视频| 日韩在线观看中文字幕 | 欧美全黄 | 免费在线观看一区二区 | 99成人精品 | 天天视频一区二区三区 | 欧美一级片在线播放 | 97视频人人澡人人爽 | 国产一区二区三区免费 |