Point Cloud Library (PCL) 1.12.1
fern_evaluator.hpp
1/*
2 * Software License Agreement (BSD License)
3 *
4 * Point Cloud Library (PCL) - www.pointclouds.org
5 * Copyright (c) 2010-2011, Willow Garage, Inc.
6 *
7 * All rights reserved.
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * * Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * * Redistributions in binary form must reproduce the above
16 * copyright notice, this list of conditions and the following
17 * disclaimer in the documentation and/or other materials provided
18 * with the distribution.
19 * * Neither the name of Willow Garage, Inc. nor the names of its
20 * contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 * POSSIBILITY OF SUCH DAMAGE.
35 *
36 */
37
38#pragma once
39
40#include <pcl/common/common.h>
41#include <pcl/ml/feature_handler.h>
42#include <pcl/ml/ferns/fern.h>
43#include <pcl/ml/stats_estimator.h>
44
45#include <vector>
46
47namespace pcl {
48
49template <class FeatureType,
50 class DataSet,
51 class LabelType,
52 class ExampleIndex,
53 class NodeType>
55{}
56
57template <class FeatureType,
58 class DataSet,
59 class LabelType,
60 class ExampleIndex,
61 class NodeType>
63{}
64
65template <class FeatureType,
66 class DataSet,
67 class LabelType,
68 class ExampleIndex,
69 class NodeType>
70void
75 DataSet& data_set,
76 std::vector<ExampleIndex>& examples,
77 std::vector<LabelType>& label_data)
78{
79 const std::size_t num_of_examples = examples.size();
80 const std::size_t num_of_branches = stats_estimator.getNumOfBranches();
81 const std::size_t num_of_features = fern.getNumOfFeatures();
82
83 label_data.resize(num_of_examples);
84
85 std::vector<std::vector<float>> results(num_of_features);
86 std::vector<std::vector<unsigned char>> flags(num_of_features);
87 std::vector<std::vector<unsigned char>> branch_indices(num_of_features);
88
89 for (std::size_t feature_index = 0; feature_index < num_of_features;
90 ++feature_index) {
91 results[feature_index].reserve(num_of_examples);
92 flags[feature_index].reserve(num_of_examples);
93 branch_indices[feature_index].reserve(num_of_examples);
94
95 feature_handler.evaluateFeature(fern.accessFeature(feature_index),
96 data_set,
97 examples,
98 results[feature_index],
99 flags[feature_index]);
100 stats_estimator.computeBranchIndices(results[feature_index],
101 flags[feature_index],
102 fern.accessThreshold(feature_index),
103 branch_indices[feature_index]);
104 }
105
106 for (std::size_t example_index = 0; example_index < num_of_examples;
107 ++example_index) {
108 std::size_t node_index = 0;
109 for (std::size_t feature_index = 0; feature_index < num_of_features;
110 ++feature_index) {
111 node_index *= num_of_branches;
112 node_index += branch_indices[feature_index][example_index];
113 }
114
115 label_data[example_index] = stats_estimator.getLabelOfNode(fern[node_index]);
116 }
117}
118
119template <class FeatureType,
120 class DataSet,
121 class LabelType,
122 class ExampleIndex,
123 class NodeType>
124void
129 DataSet& data_set,
130 std::vector<ExampleIndex>& examples,
131 std::vector<LabelType>& label_data)
132{
133 const std::size_t num_of_examples = examples.size();
134 const std::size_t num_of_branches = stats_estimator.getNumOfBranches();
135 const std::size_t num_of_features = fern.getNumOfFeatures();
136
137 std::vector<std::vector<float>> results(num_of_features);
138 std::vector<std::vector<unsigned char>> flags(num_of_features);
139 std::vector<std::vector<unsigned char>> branch_indices(num_of_features);
140
141 for (std::size_t feature_index = 0; feature_index < num_of_features;
142 ++feature_index) {
143 results[feature_index].reserve(num_of_examples);
144 flags[feature_index].reserve(num_of_examples);
145 branch_indices[feature_index].reserve(num_of_examples);
146
147 feature_handler.evaluateFeature(fern.accessFeature(feature_index),
148 data_set,
149 examples,
150 results[feature_index],
151 flags[feature_index]);
152 stats_estimator.computeBranchIndices(results[feature_index],
153 flags[feature_index],
154 fern.accessThreshold(feature_index),
155 branch_indices[feature_index]);
156 }
157
158 for (std::size_t example_index = 0; example_index < num_of_examples;
159 ++example_index) {
160 std::size_t node_index = 0;
161 for (std::size_t feature_index = 0; feature_index < num_of_features;
162 ++feature_index) {
163 node_index *= num_of_branches;
164 node_index += branch_indices[feature_index][example_index];
165 }
166
167 label_data[example_index] = stats_estimator.getLabelOfNode(fern[node_index]);
168 }
169}
170
171template <class FeatureType,
172 class DataSet,
173 class LabelType,
174 class ExampleIndex,
175 class NodeType>
176void
181 DataSet& data_set,
182 std::vector<ExampleIndex>& examples,
183 std::vector<NodeType*>& nodes)
184{
185 const std::size_t num_of_examples = examples.size();
186 const std::size_t num_of_branches = stats_estimator.getNumOfBranches();
187 const std::size_t num_of_features = fern.getNumOfFeatures();
188
189 nodes.reserve(num_of_examples);
190
191 std::vector<std::vector<float>> results(num_of_features);
192 std::vector<std::vector<unsigned char>> flags(num_of_features);
193 std::vector<std::vector<unsigned char>> branch_indices(num_of_features);
194
195 for (std::size_t feature_index = 0; feature_index < num_of_features;
196 ++feature_index) {
197 results[feature_index].reserve(num_of_examples);
198 flags[feature_index].reserve(num_of_examples);
199 branch_indices[feature_index].reserve(num_of_examples);
200
201 feature_handler.evaluateFeature(fern.accessFeature(feature_index),
202 data_set,
203 examples,
204 results[feature_index],
205 flags[feature_index]);
206 stats_estimator.computeBranchIndices(results[feature_index],
207 flags[feature_index],
208 fern.accessThreshold(feature_index),
209 branch_indices[feature_index]);
210 }
211
212 for (std::size_t example_index = 0; example_index < num_of_examples;
213 ++example_index) {
214 std::size_t node_index = 0;
215 for (std::size_t feature_index = 0; feature_index < num_of_features;
216 ++feature_index) {
217 node_index *= num_of_branches;
218 node_index += branch_indices[feature_index][example_index];
219 }
220
221 nodes.push_back(&(fern[node_index]));
222 }
223}
224
225} // namespace pcl
Utility class interface which is used for creating and evaluating features.
virtual void evaluateFeature(const FeatureType &feature, DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< float > &results, std::vector< unsigned char > &flags) const =0
Evaluates a feature on the specified data.
virtual ~FernEvaluator()
Destructor.
void evaluate(pcl::Fern< FeatureType, NodeType > &fern, pcl::FeatureHandler< FeatureType, DataSet, ExampleIndex > &feature_handler, pcl::StatsEstimator< LabelType, NodeType, DataSet, ExampleIndex > &stats_estimator, DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelType > &label_data)
Evaluates the specified examples using the supplied fern and writes the resulting label for each example to label_data.
void evaluateAndAdd(pcl::Fern< FeatureType, NodeType > &fern, pcl::FeatureHandler< FeatureType, DataSet, ExampleIndex > &feature_handler, pcl::StatsEstimator< LabelType, NodeType, DataSet, ExampleIndex > &stats_estimator, DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelType > &label_data)
Evaluates the specified examples using the supplied fern and adds the results to the supplied results array (label_data).
FernEvaluator()
Constructor.
void getNodes(pcl::Fern< FeatureType, NodeType > &fern, pcl::FeatureHandler< FeatureType, DataSet, ExampleIndex > &feature_handler, pcl::StatsEstimator< LabelType, NodeType, DataSet, ExampleIndex > &stats_estimator, DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< NodeType * > &nodes)
Evaluates the specified examples using the supplied fern and appends a pointer to the node reached by each example to nodes.
Class representing a Fern.
Definition: fern.h:49
float & accessThreshold(const std::size_t threshold_index)
Access operator for thresholds.
Definition: fern.h:186
std::size_t getNumOfFeatures()
Returns the number of features the Fern has.
Definition: fern.h:79
FeatureType & accessFeature(const std::size_t feature_index)
Access operator for features.
Definition: fern.h:166
virtual std::size_t getNumOfBranches() const =0
Returns the number of branches a node can have (e.g. 2 for a binary split).
virtual LabelDataType getLabelOfNode(NodeType &node) const =0
Returns the label of the specified node.
virtual void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const =0
Computes the branch indices obtained by the specified threshold on the supplied feature evaluation results.
Define standard C methods and C++ classes that are common to all methods.