from collections import Counter, defaultdict

import numpy as np


class DecisionTree(object):
    def __init__(self, classes, features, max_depth=10, min_samples_split=10, impurity_t='entropy'):
        self.classes = classes
        self.features = features
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.impurity_t = impurity_t
        self.root = None
        self.tree = defaultdict(list)  # feature names chosen at each depth, kept for inspection

    def get_params(self, deep=True):
        return {'classes': self.classes, 'features': self.features,
                'max_depth': self.max_depth,
                'min_samples_split': self.min_samples_split,
                'impurity_t': self.impurity_t}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

    def impurity(self, label):
        '''
        Compute impurity: information entropy or the Gini index,
        depending on impurity_t.
        label is a 1-D numpy array holding the labels of one partition
        produced by splitting on the current feature.
        '''
        cnt, total = Counter(label), float(len(label))
        probs = [cnt[v] / total for v in cnt]
        if self.impurity_t == 'gini':
            return 1 - sum(p * p for p in probs)
        return -sum(p * np.log2(p) for p in probs if p > 0)

    def gain(self, feature, label) -> tuple:
        '''Information gain ratio of splitting on one feature column,
        plus the per-value index lists of the resulting split.'''
        p_impurity = self.impurity(label)  # impurity of the parent node

        # Group sample indices by feature value.
        f_index = defaultdict(list)
        for idx, v in enumerate(feature):
            f_index[v].append(idx)

        # Weighted impurity of the child partitions.
        c_impurity = 0
        for v in f_index:
            f_l = label[f_index[v]]
            c_impurity += self.impurity(f_l) * len(f_l) / len(label)

        # Normalize by the split information to penalize many-valued features.
        r = self.impurity(feature)
        r = (p_impurity - c_impurity) / (r if r != 0 else 1)
        return r, f_index

    def expand_node(self, feature, label, depth, used_features) -> tuple:
        # All samples share one label: return it as a leaf.
        if len(set(label)) == 1:
            return label[0]
        # Majority label, used as the fallback prediction.
        most = Counter(label).most_common(1)[0][0]
        if depth > self.max_depth or len(label) < self.min_samples_split:
            return most

        # Pick the unused feature with the highest gain ratio.
        bestf, max_gain, bestf_idx = -1, -1, None
        for f in range(len(self.features)):
            if f in used_features:
                continue
            f_gain, f_idx = self.gain(feature[:, f], label)
            if bestf < 0 or f_gain > max_gain:
                bestf, max_gain, bestf_idx = f, f_gain, f_idx
        if bestf < 0:  # every feature has already been used
            return most

        # Recurse into one child per value of the chosen feature.
        children = {}
        new_used_features = used_features + [bestf]
        for v in bestf_idx:
            c_idx = bestf_idx[v]
            children[v] = self.expand_node(feature[c_idx, :], label[c_idx],
                                           depth + 1, new_used_features)
        self.tree[depth].append(self.features[bestf])
        return (bestf, children, most)

    def traverse_node(self, node, feature):
        assert len(self.features) == len(feature)
        if type(node) is not tuple:  # leaf node: a plain label
            return node
        fv = feature[node[0]]
        if fv in node[1]:  # follow the branch for this feature value
            return self.traverse_node(node[1][fv], feature)
        return node[-1]  # unseen value: fall back to the majority label

    def fit(self, feature, label):
        assert len(self.features) == len(feature[0])
        self.root = self.expand_node(feature, label, depth=1, used_features=[])

    def predict(self, feature):
        assert len(feature.shape) == 1 or len(feature.shape) == 2
        if len(feature.shape) == 1:  # a single sample
            return self.traverse_node(self.root, feature)
        return np.array([self.traverse_node(self.root, f) for f in feature])
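
# A minimal usage sketch (not from the original source): the toy feature
# matrix, labels, and feature names below are hypothetical and exist only
# to show the fit/predict API of the class above. Features are assumed to
# be categorical, encoded as small integers.
features = ['outlook', 'windy']      # hypothetical feature names
X = np.array([[0, 0],
              [0, 1],
              [1, 0],
              [2, 1]])               # one row per sample, categorical values
y = np.array([1, 0, 1, 0])           # binary class labels

dt = DecisionTree(classes=[0, 1], features=features,
                  max_depth=3, min_samples_split=1, impurity_t='gini')
dt.fit(X, y)
print(dt.predict(X))                  # batch prediction, e.g. [1 0 1 0]
print(dt.predict(np.array([2, 0])))   # a single 1-D sample also works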