基于ID3决策树算法的实现(Python版)

  • Post category:Python

基于ID3决策树算法的实现(Python版)

决策树是一种常用的机器学习算法,它可以用于分类和回归问题。ID3是一种常用的决策树算法它基于信息熵来选择最佳划分属性。本文将介绍如何使用Python实现基于ID3决策树算法的分类器。

1. 数据集

我们将使用一个简单的数据集来演示如何使用ID3算法构建决策树。这个数据集包含5个样本,每个样本两个特征:Outlook和Temperature。Outlook有三个可能的取值:Sunny、Overcast和Rainy;Temperature有两个可能的取值:Hot和Mild。每个样本都有一个类别标签:PlayTennis或NotPlayTennis。以下是数据集的示例:

Outlook Temperature PlayTennis
Sunny Hot No
Sunny Hot No
Overcast Hot Yes
Rainy Mild Yes
Rainy Hot No

2. ID3算法

ID3算法是一种基于信息熵的决策树算法。它的基本思想是最佳划分属性,使得划分后的子集尽可能地纯净。信息熵是一个用于衡量数据集纯度的指标,它的定义如下:

$$
H(X) = -\sum_{i=1}^{n}p_i\log_2p_i
$$

其中,$X$是一个数据集,$n$是$X$中不同类别的数量,$p_i$是类别$i$在$X$中出现的概率。

ID3算法的具体实现下:

  1. 计算数据集的信息熵$H(X)$。
  2. 对于每个特征$A$,计算它的信息增益$IG(A)$,并选择信息增益最大的特征作为划分属性。
  3. 使用划分属性将集划分为多个子集,每个子集对应一个特征值。
  4. 对于每个子集,如果它的类别标签不全相同,则递归地应用上述步骤,直到所有子集的类别标签完全相同或者没有更多特征可用止。

信息增益是一个用于衡量特征对数据集分类能力的指标,它的定义如下:

$$
IG(A) = H(X) – \sum_{v\in Values(A)}\frac{|X_v|}{|X|}H(X_v$$

其中,$A$是一个特征,$Values(A)$是$A$的所有可能取值,$X_v$是$X$中所有特征$A$取值为$v$的样本组成的子集。

3. Python实现

我们将使用Python实现基于ID3算法的决策树分类器。以下是整的代码:

import math
from collections import Counter

class DecisionTree:
    def __init__(self):
        self.tree = {}

    def fit(self, X, y):
        self.tree = self.build_tree(X, y)

    def predict(self, X):
        return [self.predict_one(x, self.tree) for x in X]

    def predict_one(self, x, tree):
        if not isinstance(tree, dict):
            return tree
        feature, value_dict = next(iter(tree.items()))
        value = x.get(feature)
        if value not in value_dict:
            return Counter(value_dict.values()).most_common(1)[0][0]
        return self.predict_one(x, value_dict[value])

    def build_tree(self, X, y):
        if len(set(y)) == 1:
            return y[0]
        if not X:
            return Counter(y).most_common(1)[0][0]
        best_feature = self.choose_best_feature(X, y)
        tree = {best_feature: {}}
        for value in set(x[best_feature] for x in X):
            X_v = [x for x in X if x[best_feature] == value]
            y_v = [y[i] for i, x in enumerate(X) if x[best_feature] == value]
            tree[best_feature][value] = self.build_tree(X_v, y_v)
        return tree

    def choose_best_feature(self, X, y):
        base_entropy = self.entropy(y)
        best_info_gain, best_feature = -1, None
        for feature in X[0]:
            info_gain = base_entropy - self.conditional_entropy(X, y, feature)
            if info_gain > best_info_gain:
                best_info_gain, best_feature = info_gain, feature
        return best_feature

    def entropy(self, y):
        counter = Counter(y)
        probs = [counter[c] / len(y) for c in set(y)]
        return -sum(p * math.log2(p) for p in probs)

    def conditional_entropy(self, X, y, feature):
       _values = set(x[feature] for x in X)
        probs = [sum(1 for x in X if x[feature] == value) / len(X) for value in feature_values]
 entropies = [self.entropy([y[i] for i, x in enumerate(X) if x[feature] == value]) for value in feature_values]
        return sum(p * e for p, e in zip(pro, entropies))

这个代码实现了一个名为DecisionTree的类,它包含三个方法:

  • fit(X, y):用于训练决策树分类器,其中X是一个二维数组,每行表示一个本,每列表示一个特征;y是一个一维数组,表示每个样本的类别标签。
  • predict(X):用于对新样进行分类,其中X是一个二维数组,每行表示一个样本,每列表示一个特征;返回一个一维数组,表示每个样本的类别标签。
  • build_tree(X, y):用于构建决策树,其中X和y的含义与fit相同。

4. 示例

以下是如何使用上述代码对数据集进行分类的示例:

X = [
    {'Outlook': 'Sunny', 'Temperature': 'Hot'},
    {'Outlook': 'Sunny', 'Temperature': 'Hot'},
    {'Outlook': 'Overcast', 'Temperature': 'Hot'},
 {'Outlook': 'Rainy', 'Temperature': 'Mild'},
    {'Outlook': 'Rainy', 'Temperature': 'Hot'},
]
y = ['No', 'No', 'Yes', 'Yes', 'No']

clf = DecisionTree()
clf.fit(X, y)

X_test = [
    {'Outlook': 'Sunny', 'Temperature': 'Mild'},
    {'Outlook': 'Overcast', 'Temperature': 'Mild'},
    {'Outlook': 'Rainy', '': 'Mild'},
]
y_pred = clf.predict(X_test)

print(y_pred)  # ['No', 'Yes', 'Yes']

这个示例将使用上述代码对数据集进行分类,并输出预测。

5. 总结

本文介绍了如何使用Python实现基于ID3算法的决策树分类器。决策树是一种常用的机器学习算法,它可以用于分类和回归问题。ID3算法是一种基于信息熵的决策树算法,它的基本思想是选择最佳划属性,使得划分后的子集尽可能地纯净。在实际应用中,我们可以根据数据集的特点选择合适的决树算法,并使用Python实现相应的分类器。

示例说明

示例1

在示例1中,我们使用了一个包含5个样本的数据集,每个样本有两个特征:Outlook和Temperature。我们使用DecisionTree类训练了一个决策树分类器,并使用X_test对新样本进行了分类。最终输出了预测结果。

示例2

在示例2中,我们使用了一个包含14个样本的数据集,每个样本有四个特征:sepal length、sepal width、petal length和petal width。我们使用DecisionTree类训练了一个决策树分类器,并使用X_test对新样本进行了分类。最终输出了预测结果。