Point Cloud Library (PCL) 1.13.0 — decision_tree_trainer.hpp
/*
 * Software License Agreement (BSD License)
 *
 * Point Cloud Library (PCL) - www.pointclouds.org
 * Copyright (c) 2010-2011, Willow Garage, Inc.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above
 * copyright notice, this list of conditions and the following
 * disclaimer in the documentation and/or other materials provided
 * with the distribution.
 * * Neither the name of Willow Garage, Inc. nor the names of its
 * contributors may be used to endorse or promote products derived
 * from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 */

#pragma once

namespace pcl {
42template <class FeatureType,
43 class DataSet,
44 class LabelType,
45 class ExampleIndex,
46 class NodeType>
49: max_tree_depth_(15)
50, num_of_features_(1000)
51, num_of_thresholds_(10)
52, feature_handler_(nullptr)
53, stats_estimator_(nullptr)
54, data_set_()
55, label_data_()
56, examples_()
57, decision_tree_trainer_data_provider_()
58, random_features_at_split_node_(false)
59{}
60
61template <class FeatureType,
62 class DataSet,
63 class LabelType,
64 class ExampleIndex,
65 class NodeType>
67 ~DecisionTreeTrainer() = default;
68
69template <class FeatureType,
70 class DataSet,
71 class LabelType,
72 class ExampleIndex,
73 class NodeType>
74void
77{
78 // create random features
79 std::vector<FeatureType> features;
80
81 if (!random_features_at_split_node_)
82 feature_handler_->createRandomFeatures(num_of_features_, features);
83
84 // recursively build decision tree
85 NodeType root_node;
86 tree.setRoot(root_node);
87
88 if (decision_tree_trainer_data_provider_) {
89 std::cerr << "use decision_tree_trainer_data_provider_" << std::endl;
90
91 decision_tree_trainer_data_provider_->getDatasetAndLabels(
92 data_set_, label_data_, examples_);
93 trainDecisionTreeNode(
94 features, examples_, label_data_, max_tree_depth_, tree.getRoot());
95 label_data_.clear();
96 data_set_.clear();
97 examples_.clear();
98 }
99 else {
100 trainDecisionTreeNode(
101 features, examples_, label_data_, max_tree_depth_, tree.getRoot());
102 }
103}
104
105template <class FeatureType,
106 class DataSet,
107 class LabelType,
108 class ExampleIndex,
109 class NodeType>
110void
112 trainDecisionTreeNode(std::vector<FeatureType>& features,
113 std::vector<ExampleIndex>& examples,
114 std::vector<LabelType>& label_data,
115 const std::size_t max_depth,
116 NodeType& node)
117{
118 const std::size_t num_of_examples = examples.size();
119 if (num_of_examples == 0) {
120 PCL_ERROR(
121 "Reached invalid point in decision tree training: Number of examples is 0!\n");
122 return;
123 };
124
125 if (max_depth == 0) {
126 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
127 return;
128 };
129
130 if (examples.size() < min_examples_for_split_) {
131 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
132 return;
133 }
134
135 if (random_features_at_split_node_) {
136 features.clear();
137 feature_handler_->createRandomFeatures(num_of_features_, features);
138 }
139
140 std::vector<float> feature_results;
141 std::vector<unsigned char> flags;
142
143 feature_results.reserve(num_of_examples);
144 flags.reserve(num_of_examples);
145
146 // find best feature for split
147 int best_feature_index = -1;
148 float best_feature_threshold = 0.0f;
149 float best_feature_information_gain = 0.0f;
150
151 const std::size_t num_of_features = features.size();
152 for (std::size_t feature_index = 0; feature_index < num_of_features;
153 ++feature_index) {
154 // evaluate features
155 feature_handler_->evaluateFeature(
156 features[feature_index], data_set_, examples, feature_results, flags);
157
158 // get list of thresholds
159 if (!thresholds_.empty()) {
160 // compute information gain for each threshold and store threshold with highest
161 // information gain
162 for (std::size_t threshold_index = 0; threshold_index < thresholds_.size();
163 ++threshold_index) {
164
165 const float information_gain =
166 stats_estimator_->computeInformationGain(data_set_,
167 examples,
168 label_data,
169 feature_results,
170 flags,
171 thresholds_[threshold_index]);
172
173 if (information_gain > best_feature_information_gain) {
174 best_feature_information_gain = information_gain;
175 best_feature_index = static_cast<int>(feature_index);
176 best_feature_threshold = thresholds_[threshold_index];
177 }
178 }
179 }
180 else {
181 std::vector<float> thresholds;
182 thresholds.reserve(num_of_thresholds_);
183 createThresholdsUniform(num_of_thresholds_, feature_results, thresholds);
184
185 // compute information gain for each threshold and store threshold with highest
186 // information gain
187 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
188 ++threshold_index) {
189 const float threshold = thresholds[threshold_index];
190
191 // compute information gain
192 const float information_gain = stats_estimator_->computeInformationGain(
193 data_set_, examples, label_data, feature_results, flags, threshold);
194
195 if (information_gain > best_feature_information_gain) {
196 best_feature_information_gain = information_gain;
197 best_feature_index = static_cast<int>(feature_index);
198 best_feature_threshold = threshold;
199 }
200 }
201 }
202 }
203
204 if (best_feature_index == -1) {
205 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
206 return;
207 }
208
209 // get branch indices for best feature and best threshold
210 std::vector<unsigned char> branch_indices;
211 branch_indices.reserve(num_of_examples);
212 {
213 feature_handler_->evaluateFeature(
214 features[best_feature_index], data_set_, examples, feature_results, flags);
215
216 stats_estimator_->computeBranchIndices(
217 feature_results, flags, best_feature_threshold, branch_indices);
218 }
219
220 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
221
222 // separate data
223 {
224 const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
225
226 std::vector<std::size_t> branch_counts(num_of_branches, 0);
227 for (std::size_t example_index = 0; example_index < num_of_examples;
228 ++example_index) {
229 ++branch_counts[branch_indices[example_index]];
230 }
231
232 node.feature = features[best_feature_index];
233 node.threshold = best_feature_threshold;
234 node.sub_nodes.resize(num_of_branches);
235
236 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
237 if (branch_counts[branch_index] == 0) {
238 NodeType branch_node;
239 stats_estimator_->computeAndSetNodeStats(
240 data_set_, examples, label_data, branch_node);
241 // branch_node->num_of_sub_nodes = 0;
242
243 node.sub_nodes[branch_index] = branch_node;
244
245 continue;
246 }
247
248 std::vector<LabelType> branch_labels;
249 std::vector<ExampleIndex> branch_examples;
250 branch_labels.reserve(branch_counts[branch_index]);
251 branch_examples.reserve(branch_counts[branch_index]);
252
253 for (std::size_t example_index = 0; example_index < num_of_examples;
254 ++example_index) {
255 if (branch_indices[example_index] == branch_index) {
256 branch_examples.push_back(examples[example_index]);
257 branch_labels.push_back(label_data[example_index]);
258 }
259 }
260
261 trainDecisionTreeNode(features,
262 branch_examples,
263 branch_labels,
264 max_depth - 1,
265 node.sub_nodes[branch_index]);
266 }
267 }
268}
269
270template <class FeatureType,
271 class DataSet,
272 class LabelType,
273 class ExampleIndex,
274 class NodeType>
275void
277 createThresholdsUniform(const std::size_t num_of_thresholds,
278 std::vector<float>& values,
279 std::vector<float>& thresholds)
280{
281 // estimate range of values
282 float min_value = ::std::numeric_limits<float>::max();
283 float max_value = -::std::numeric_limits<float>::max();
284
285 const std::size_t num_of_values = values.size();
286 for (std::size_t value_index = 0; value_index < num_of_values; ++value_index) {
287 const float value = values[value_index];
288
289 if (value < min_value)
290 min_value = value;
291 if (value > max_value)
292 max_value = value;
293 }
294
295 const float range = max_value - min_value;
296 const float step = range / static_cast<float>(num_of_thresholds + 2);
297
298 // compute thresholds
299 thresholds.resize(num_of_thresholds);
300
301 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds;
302 ++threshold_index) {
303 thresholds[threshold_index] =
304 min_value + step * (static_cast<float>(threshold_index + 1));
305 }
306}

} // namespace pcl
Class representing a decision tree.
NodeType & getRoot()
Returns the root node of the tree.
void setRoot(const NodeType &root)
Sets the root node of the tree.
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
void trainDecisionTreeNode(std::vector< FeatureType > &features, std::vector< ExampleIndex > &examples, std::vector< LabelType > &label_data, std::size_t max_depth, NodeType &node)
Trains a decision tree node from the specified features, label data, and examples.
void train(DecisionTree< NodeType > &tree)
Trains a decision tree using the set training data and settings.
virtual ~DecisionTreeTrainer()
Destructor.