Add extremely random forest implementation to TF/contrib.
Change: 118888391
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 366af6b..a5ba85f 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -21,6 +21,7 @@
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/skflow",
+ "//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/contrib/util:util_py",
],
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
new file mode 100644
index 0000000..33c66a6b
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -0,0 +1,194 @@
+# Tensorflow code for training random forests.
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
+
+cc_library(
+ name = "tree_utils",
+ srcs = ["core/ops/tree_utils.cc"],
+ hdrs = [
+ "core/ops/tree_utils.h",
+ ],
+ deps = [
+ "//google/protobuf",
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_custom_op_library(
+ name = "python/ops/_inference_ops.so",
+ srcs = [
+ "core/ops/tree_predictions_op.cc",
+ ],
+ deps = [":tree_utils"],
+)
+
+tf_custom_op_library(
+ name = "python/ops/_training_ops.so",
+ srcs = [
+ "core/ops/best_splits_op.cc",
+ "core/ops/count_extremely_random_stats_op.cc",
+ "core/ops/finished_nodes_op.cc",
+ "core/ops/grow_tree_op.cc",
+ "core/ops/sample_inputs_op.cc",
+ "core/ops/scatter_add_ndim_op.cc",
+ "core/ops/update_fertile_slots_op.cc",
+ ],
+ deps = [":tree_utils"],
+)
+
+py_library(
+ name = "ops_lib",
+ srcs = [
+ "__init__.py",
+ "python/ops/inference_ops.py",
+ "python/ops/training_ops.py",
+ ],
+ data = [
+ "python/ops/_inference_ops.so",
+ "python/ops/_training_ops.so",
+ ],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "best_splits_op_test",
+ size = "small",
+ srcs = ["python/kernel_tests/best_splits_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops_lib",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_test(
+ name = "count_extremely_random_stats_op_test",
+ size = "small",
+ srcs = ["python/kernel_tests/count_extremely_random_stats_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops_lib",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_test(
+ name = "grow_tree_op_test",
+ size = "small",
+ srcs = ["python/kernel_tests/grow_tree_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops_lib",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_test(
+ name = "finished_nodes_op_test",
+ size = "small",
+ srcs = ["python/kernel_tests/finished_nodes_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops_lib",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_test(
+ name = "sample_inputs_op_test",
+ size = "small",
+ srcs = ["python/kernel_tests/sample_inputs_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops_lib",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_test(
+ name = "scatter_add_ndim_op_test",
+ size = "small",
+ srcs = ["python/kernel_tests/scatter_add_ndim_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops_lib",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_test(
+ name = "tree_predictions_op_test",
+ size = "small",
+ srcs = ["python/kernel_tests/tree_predictions_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops_lib",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_test(
+ name = "update_fertile_slots_op_test",
+ size = "small",
+ srcs = ["python/kernel_tests/update_fertile_slots_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops_lib",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_library(
+ name = "tensor_forest_py",
+ srcs = ["python/tensor_forest.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops_lib",
+ ],
+)
+
+py_test(
+ name = "tensor_forest_test",
+ size = "small",
+ srcs = ["python/tensor_forest_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":tensor_forest_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
diff --git a/tensorflow/contrib/tensor_forest/OWNERS b/tensorflow/contrib/tensor_forest/OWNERS
new file mode 100644
index 0000000..becb4d2
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/OWNERS
@@ -0,0 +1,3 @@
+dsculley
+gilberth
+thomaswc
\ No newline at end of file
diff --git a/tensorflow/contrib/tensor_forest/__init__.py b/tensorflow/contrib/tensor_forest/__init__.py
new file mode 100644
index 0000000..878d9c5
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Imports training and inference custom ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensor_forest.python.ops.inference_ops import *
+from tensorflow.contrib.tensor_forest.python.ops.training_ops import *
diff --git a/tensorflow/contrib/tensor_forest/core/ops/best_splits_op.cc b/tensorflow/contrib/tensor_forest/core/ops/best_splits_op.cc
new file mode 100644
index 0000000..3c454be
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/best_splits_op.cc
@@ -0,0 +1,117 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// BestSplits returns the index of the best candidate for each finished node.
+// This decision is based on the Gini score of the pcw_candidate_split counts,
+// and the right-branch-taken counts inferred from pcw_total_splits.
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+
+namespace tensorflow {
+
+using tensorforest::BestFeature;
+
+
+REGISTER_OP("BestSplits")
+ .Input("finished_nodes: int32")
+ .Input("node_to_accumulator: int32")
+ .Input("pcw_candidate_splits: float")
+ .Input("pcw_total_splits: float")
+ .Output("split_indices: int32")
+ .Doc(R"doc(
+ Returns the index of the best split for each finished node.
+
+ The best split is the split with the lowest weighted Gini impurity,
+ as calculated from the statistics in `pcw_candidate_splits` and
+ `pcw_total_splits`.
+
+ finished_nodes:= A 1-d int32 tensor containing the indices of finished nodes.
+ node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
+ fertile node i, or -1 if node i isn't fertile.
+ pcw_candidate_splits: `pcw_candidate_splits[a][s][c]` records how many
+ training examples have class c and have ended up in the fertile node
+ associated with accumulator slot a and then taken the *left* branch of
+ candidate split s.
+ pcw_total_splits: `pcw_total_splits[a][c]` records how many training examples
+ have class c and have ended up in the fertile node associated with
+ accumulator slot a. Between that and `pcw_candidate_splits`, the number of
+ examples taking the right branch of a split can be reconstructed.
+ split_indices: `split_indices[i]` contains the index of the split to use for
+ `finished_nodes[i]`.
+)doc");
+
+
+class BestSplits : public OpKernel {
+ public:
+ explicit BestSplits(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& finished = context->input(0);
+ const Tensor& node_to_accumulator = context->input(1);
+ const Tensor& pcw_candidate_splits = context->input(2);
+ const Tensor& pcw_total_splits = context->input(3);
+
+ OP_REQUIRES(context, finished.shape().dims() == 1,
+ errors::InvalidArgument(
+ "finished should be one-dimensional"));
+ OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+ errors::InvalidArgument(
+ "node_to_accumulator should be one-dimensional"));
+
+ OP_REQUIRES(context, pcw_candidate_splits.shape().dims() == 3,
+ errors::InvalidArgument(
+ "pcw_candidate_splits should be three-dimensional"));
+ OP_REQUIRES(context, pcw_total_splits.shape().dims() == 2,
+ errors::InvalidArgument(
+ "pcw_total_splits should be two-dimensional"));
+
+ OP_REQUIRES(
+ context,
+ pcw_candidate_splits.shape().dim_size(0) ==
+ pcw_total_splits.shape().dim_size(0),
+ errors::InvalidArgument(
+ "Number of accumulators should be the same in pcw_candidate_splits "
+ "and pcw_total_splits."));
+
+ Tensor* output_splits = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, finished.shape(),
+ &output_splits));
+ auto best_splits = output_splits->unaligned_flat<int32>();
+
+ const auto finished_vec = finished.unaligned_flat<int32>();
+ const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+
+ const int32 num_finished = finished.shape().dim_size(0);
+
+ for (int i = 0; i < num_finished; i++) {
+ const int32 node = finished_vec(i);
+ const int32 accumulator = node_map(node);
+ if (accumulator < 0) {
+ LOG(ERROR) << "Something has gone wrong, we got a finished node that "
+ << "doesn't have an accumulator allocated to it.";
+ continue;
+ }
+ best_splits(i) = BestFeature(pcw_total_splits,
+ pcw_candidate_splits, accumulator);
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("BestSplits").Device(DEVICE_CPU), BestSplits);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
new file mode 100644
index 0000000..ab5ac9c
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
@@ -0,0 +1,331 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// CountExtremelyRandomStats outputs count-deltas that should be added to
+// the node pcws, candidate split pcws, and total split pcws. It also outputs
+// the leaves that each input arrived at for use in SampleInputs. This is the
+// only op that involves tree traversal, and is constructed so that it can
+// be run in parallel on separate batches of data.
+#include <unordered_map>
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+using std::get;
+using std::make_tuple;
+using std::pair;
+using std::tuple;
+
+using tensorforest::CHILDREN_INDEX;
+using tensorforest::FEATURE_INDEX;
+using tensorforest::LEAF_NODE;
+using tensorforest::FREE_NODE;
+
+using tensorforest::DecideNode;
+using tensorforest::Initialize;
+using tensorforest::IsAllInitialized;
+
+REGISTER_OP("CountExtremelyRandomStats")
+ .Attr("num_classes: int32")
+ .Input("input_data: float")
+
+ .Input("input_labels: int32")
+
+ .Input("tree: int32")
+ .Input("tree_thresholds: float")
+
+ .Input("node_to_accumulator: int32")
+
+ .Input("candidate_split_features: int32")
+ .Input("candidate_split_thresholds: float")
+
+ .Output("pcw_node_delta: float")
+ .Output("pcw_splits_indices: int32")
+ .Output("pcw_candidate_splits_delta: float")
+ .Output("pcw_totals_indices: int32")
+ .Output("pcw_total_splits_delta: float")
+
+ .Output("leaves: int32")
+ .Doc(R"doc(
+ Calculates incremental statistics for a batch of training data.
+
+ Each training example in `input_data` is sent through the decision tree
+ represented by `tree` and `tree_thresholds`. `pcw_node_delta[i]` is
+ incremented for every node i that it passes through, and the leaf it ends up
+ in is recorded in `leaves[i]`. Then, if the leaf is fertile and
+ initialized, the statistics for its corresponding accumulator slot
+ are updated in `pcw_candidate_splits_delta` and `pcw_total_splits_delta`.
+
+ The attr `num_classes` is needed to appropriately size the outputs.
+
+ input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
+ gives the j-th feature of the i-th input.
+ input_labels: The training batch's labels; `input_labels[i]` is the class
+ of the i-th input.
+ tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
+ of the i-th node, `tree[i][0] + 1` gives the index of the right child of
+ the i-th node, and `tree[i][1]` gives the index of the feature used to
+ split the i-th node.
+ tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
+ node.
+ node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]`
+ is its accumulator slot. Otherwise, `node_to_accumulator[i]` is -1.
+ candidate_split_features: `candidate_split_features[a][s]` is the
+ index of the feature being considered by split s of accumulator slot a.
+ candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the
+ threshold value being considered by split s of accumulator slot a.
+ pcw_node_delta: `pcw_node_delta[i][c]` is the number of training examples
+ in this training batch with class c that passed through node i.
+ pcw_splits_indices:= A 2-d tensor of shape (?, 3).
+ `pcw_splits_indices[i]` gives the coordinates of an entry in
+ candidate_split_per_class_weights that needs to be updated.
+ This is meant to be passed with `pcw_candidate_splits_delta` to a
+ scatter_add for candidate_split_per_class_weights:
+ training_ops.scatter_add_ndim(candidate_split_per_class_weights,
+ pcw_splits_indices, pcw_candidate_splits_delta)
+ pcw_candidate_splits_delta: `pcw_candidate_splits_delta[i]` is the
+ number of training examples in this training batch that correspond to
+ the i-th entry in `pcw_splits_indices` which took the *left* branch of
+ candidate split.
+ pcw_totals_indices: `pcw_totals_indices` contains the indices (accumulator,
+ class) into total_per_class_weights to update with pcw_total_splits_delta.
+ pcw_total_splits_delta: `pcw_total_splits_delta[i]` is the number of
+ training examples in this batch that ended up in the fertile
+ node with accumulator and class indicated by `pcw_totals_indices[i]`.
+ leaves: `leaves[i]` is the leaf that input i ended up in.
+)doc");
+
+
+class CountExtremelyRandomStats : public OpKernel {
+ public:
+ explicit CountExtremelyRandomStats(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(
+ "num_classes", &num_classes_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input_data = context->input(0);
+ const Tensor& input_labels = context->input(1);
+ const Tensor& tree_tensor = context->input(2);
+ const Tensor& tree_thresholds = context->input(3);
+ const Tensor& node_to_accumulator = context->input(4);
+ const Tensor& candidate_split_features = context->input(5);
+ const Tensor& candidate_split_thresholds = context->input(6);
+
+ // Check inputs.
+ OP_REQUIRES(context, input_data.shape().dims() == 2,
+ errors::InvalidArgument(
+ "input_data should be two-dimensional"));
+ OP_REQUIRES(context, input_labels.shape().dims() == 1,
+ errors::InvalidArgument(
+ "input_labels should be one-dimensional"));
+
+ OP_REQUIRES(context, tree_tensor.shape().dims() == 2,
+ errors::InvalidArgument(
+ "tree should be two-dimensional"));
+ OP_REQUIRES(context, tree_thresholds.shape().dims() == 1,
+ errors::InvalidArgument(
+ "tree_thresholds should be one-dimensional"));
+ OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+ errors::InvalidArgument(
+ "node_to_accumulator should be one-dimensional"));
+ OP_REQUIRES(context, candidate_split_features.shape().dims() == 2,
+ errors::InvalidArgument(
+ "candidate_split_features should be two-dimensional"));
+ OP_REQUIRES(context, candidate_split_thresholds.shape().dims() == 2,
+ errors::InvalidArgument(
+ "candidate_split_thresholds should be two-dimensional"));
+
+ OP_REQUIRES(
+ context,
+ input_data.shape().dim_size(0) == input_labels.shape().dim_size(0),
+ errors::InvalidArgument(
+ "Number of inputs should be the same in "
+ "input_data and input_labels."));
+ OP_REQUIRES(
+ context,
+ tree_tensor.shape().dim_size(0) ==
+ tree_thresholds.shape().dim_size(0) &&
+ tree_tensor.shape().dim_size(0) ==
+ node_to_accumulator.shape().dim_size(0),
+ errors::InvalidArgument(
+ "Number of nodes should be the same in "
+ "tree, tree_thresholds, and node_to_accumulator"));
+ OP_REQUIRES(
+ context,
+ candidate_split_features.shape() == candidate_split_thresholds.shape(),
+ errors::InvalidArgument(
+ "candidate_split_features and candidate_split_thresholds should be "
+ "the same shape."));
+
+ const int32 num_splits = candidate_split_features.shape().dim_size(1);
+
+ // node pcw delta
+ Tensor* output_node_pcw_delta = nullptr;
+ TensorShape node_pcw_shape;
+ node_pcw_shape.AddDim(tree_tensor.shape().dim_size(0));
+ node_pcw_shape.AddDim(num_classes_);
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, node_pcw_shape,
+ &output_node_pcw_delta));
+ Initialize<float>(*output_node_pcw_delta, 0);
+ auto out_node = output_node_pcw_delta->tensor<float, 2>();
+
+ // leaves
+ Tensor* output_leaves = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(5, input_labels.shape(),
+ &output_leaves));
+ auto out_leaves = output_leaves->unaligned_flat<int32>();
+
+ const auto tree = tree_tensor.tensor<int32, 2>();
+ const auto thresholds = tree_thresholds.unaligned_flat<float>();
+ const auto labels = input_labels.unaligned_flat<int32>();
+ const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+ const auto split_features = candidate_split_features.tensor<int32, 2>();
+ const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
+
+ const int32 num_data = input_data.shape().dim_size(0);
+
+ // <accumulator, class> -> count delta
+ std::unordered_map<pair<int32, int32>, int32, PairIntHash> total_delta;
+ // <accumulator, split, class> -> count delta
+ std::unordered_map<tuple<int32, int32, int32>,
+ int32, TupleIntHash> split_delta;
+ for (int i = 0; i < num_data; i++) {
+ const Tensor point = input_data.Slice(i, i+1);
+ int node_index = 0;
+ while (true) {
+ const int32 label = labels(i);
+ ++out_node(node_index, label);
+ int32 left_child = tree(node_index, CHILDREN_INDEX);
+ if (left_child == LEAF_NODE) {
+ out_leaves(i) = node_index;
+ const int32 accumulator = node_map(node_index);
+ // If the leaf is not fertile or is not yet initialized, we don't
+ // count it in the candidate/total split per-class-weights because
+ // it won't have any candidate splits yet.
+ if (accumulator >= 0 &&
+ IsAllInitialized(
+ candidate_split_features.Slice(accumulator,
+ accumulator + 1))) {
+ ++total_delta[std::make_pair(accumulator, label)];
+ for (int split = 0; split < num_splits; split++) {
+ if (!DecideNode(point, split_features(accumulator, split),
+ split_thresholds(accumulator, split))) {
+ ++split_delta[make_tuple(accumulator, split, label)];
+ }
+ }
+ }
+ break;
+ } else if (left_child == FREE_NODE) {
+ LOG(ERROR) << "Reached a free node, not good.";
+ out_leaves(i) = FREE_NODE;
+ break;
+ }
+ node_index = left_child +
+ DecideNode(point, tree(node_index, FEATURE_INDEX),
+ thresholds(node_index));
+ }
+ }
+
+ // candidate splits pcw indices
+ Tensor* output_candidate_pcw_indices = nullptr;
+ TensorShape candidate_pcw_shape;
+ candidate_pcw_shape.AddDim(split_delta.size());
+ candidate_pcw_shape.AddDim(3);
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, candidate_pcw_shape,
+ &output_candidate_pcw_indices));
+ auto out_candidate_indices =
+ output_candidate_pcw_indices->tensor<int32, 2>();
+
+ // candidate splits pcw delta
+ Tensor* output_candidate_pcw_delta = nullptr;
+ TensorShape candidate_pcw_delta_shape;
+ candidate_pcw_delta_shape.AddDim(split_delta.size());
+ OP_REQUIRES_OK(context,
+ context->allocate_output(2, candidate_pcw_delta_shape,
+ &output_candidate_pcw_delta));
+ auto out_candidate = output_candidate_pcw_delta->unaligned_flat<float>();
+
+ // total splits indices
+ Tensor* output_total_pcw_indices = nullptr;
+ TensorShape total_pcw_shape;
+ total_pcw_shape.AddDim(total_delta.size());
+ total_pcw_shape.AddDim(2);
+ OP_REQUIRES_OK(context,
+ context->allocate_output(3, total_pcw_shape,
+ &output_total_pcw_indices));
+ auto out_total_indices = output_total_pcw_indices->tensor<int32, 2>();
+
+ // total splits delta
+ Tensor* output_total_pcw_delta = nullptr;
+ TensorShape total_pcw_delta_shape;
+ total_pcw_delta_shape.AddDim(total_delta.size());
+ OP_REQUIRES_OK(context,
+ context->allocate_output(4, total_pcw_delta_shape,
+ &output_total_pcw_delta));
+ auto out_total = output_total_pcw_delta->unaligned_flat<float>();
+
+ // Copy total deltas to output.
+ int32 output_slot = 0;
+ for (const auto& updates : total_delta) {
+ out_total_indices(output_slot, 0) = updates.first.first;
+ out_total_indices(output_slot, 1) = updates.first.second;
+ out_total(output_slot) = updates.second;
+ ++output_slot;
+ }
+
+ // Copy split deltas to output.
+ output_slot = 0;
+ for (const auto& updates : split_delta) {
+ out_candidate_indices(output_slot, 0) = get<0>(updates.first);
+ out_candidate_indices(output_slot, 1) = get<1>(updates.first);
+ out_candidate_indices(output_slot, 2) = get<2>(updates.first);
+ out_candidate(output_slot) = updates.second;
+ ++output_slot;
+ }
+ }
+
+ private:
+ struct PairIntHash {
+ public:
+ std::size_t operator()(const std::pair<int, int>& x) const {
+ return std::hash<int>()(x.first) ^ std::hash<int>()(x.second);
+ }
+ };
+
+ struct TupleIntHash {
+ public:
+ std::size_t operator()(const std::tuple<int32, int32, int32>& x) const {
+ return std::hash<int32>()(get<0>(x)) ^ std::hash<int32>()(get<1>(x)) ^
+ std::hash<int32>()(get<2>(x));
+ }
+ };
+
+ int32 num_classes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("CountExtremelyRandomStats").Device(DEVICE_CPU),
+ CountExtremelyRandomStats);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
new file mode 100644
index 0000000..804ed44
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
@@ -0,0 +1,112 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// FinishedNodes returns a 1-D tensor listing the nodes that are finished
+// accumulating.
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+using tensorforest::Sum;
+
+REGISTER_OP("FinishedNodes")
+ .Attr("num_split_after_samples: int32")
+ .Input("leaves: int32")
+ .Input("node_to_accumulator: int32")
+ .Input("pcw_total_splits: float")
+
+ .Output("finished: int32")
+ .Doc(R"doc(
+ Determines which of the given leaf nodes are done accumulating.
+
+ leaves:= A 1-d int32 tensor. Lists the nodes that are currently leaves.
+ node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]`
+ is its accumulator slot. Otherwise, `node_to_accumulator[i]` is -1.
+ pcw_total_splits: `pcw_total_splits[a][c]` records how many training examples
+ have class c and have ended up in the fertile node associated with
+ accumulator slot a. Between that and `pcw_candidate_splits`, the number of
+ examples taking the right branch of a split can be reconstructed.
+ finished:= A 1-d int32 tensor. Contains the nodes that have total split
+ counts greater than or equal to the num_split_after_samples attribute.
+)doc");
+
+
+class FinishedNodes : public OpKernel {
+ public:
+ explicit FinishedNodes(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(
+ "num_split_after_samples", &num_split_after_samples_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& leaf_tensor = context->input(0);
+ const Tensor& node_to_accumulator = context->input(1);
+ const Tensor& pcw_total_splits = context->input(2);
+
+ OP_REQUIRES(context, leaf_tensor.shape().dims() == 1,
+ errors::InvalidArgument(
+ "leaf_tensor should be one-dimensional"));
+ OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+ errors::InvalidArgument(
+ "node_to_accumulator should be one-dimensional"));
+ OP_REQUIRES(context, pcw_total_splits.shape().dims() == 2,
+ errors::InvalidArgument(
+ "pcw_total_splits should be two-dimensional"));
+
+ const auto leaves = leaf_tensor.unaligned_flat<int32>();
+ const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+
+ const int32 num_leaves = leaf_tensor.shape().dim_size(0);
+
+ std::vector<int32> finished;
+ for (int i = 0; i < num_leaves; i++) {
+ const int32 leaf = leaves(i);
+ const int32 accumulator = node_map(leaf);
+ if (accumulator < 0) {
+ continue;
+ }
+
+ if (Sum<float>(pcw_total_splits.Slice(accumulator, accumulator + 1)) >=
+ num_split_after_samples_) {
+ finished.push_back(leaf);
+ }
+ }
+
+ // Copy to output.
+ Tensor* output_finished = nullptr;
+ TensorShape finished_shape;
+ finished_shape.AddDim(finished.size());
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, finished_shape,
+ &output_finished));
+ auto out_finished = output_finished->unaligned_flat<int32>();
+
+ for (int32 i = 0; i < finished.size(); i++) {
+ out_finished(i) = finished[i];
+ }
+ }
+
+ private:
+ int32 num_split_after_samples_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("FinishedNodes").Device(DEVICE_CPU),
+ FinishedNodes);
+
+} // namespace tensorflow
+
diff --git a/tensorflow/contrib/tensor_forest/core/ops/grow_tree_op.cc b/tensorflow/contrib/tensor_forest/core/ops/grow_tree_op.cc
new file mode 100644
index 0000000..2fde74b
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/grow_tree_op.cc
@@ -0,0 +1,259 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// GrowTree adds children to the tree for finished nodes by using the
+// end_of_tree tensor as an indicator for where free nodes are in the
+// pre-allocated tree tensor.
+// For example if the tree is:
+// 1, -1, -1, -2, -2, -2, ...
+// Then end_of_tree should be 3 (the first -2, or "free" slot in the tensor).
+// If node 1 is now finished, the tree tensor after this op would be:
+// 1, 3, -1, -1, -1, -2, ...
+// and end_of_tree would be 5.
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+
+
+namespace tensorflow {
+
+using tensorforest::CHILDREN_INDEX;
+using tensorforest::FEATURE_INDEX;
+
+using tensorforest::LEAF_NODE;
+
+
+REGISTER_OP("GrowTree")
+ .Input("end_of_tree: int32")
+ .Input("tree_depths: int32")
+ .Input("node_to_accumulator: int32")
+ .Input("finished_nodes: int32")
+ .Input("best_splits: int32")
+ .Input("candidate_split_features: int32")
+ .Input("candidate_split_thresholds: float")
+ .Output("nodes_to_update: int32")
+ .Output("tree_updates: int32")
+ .Output("threshold_updates: float")
+ .Output("depth_updates: int32")
+ .Output("new_end_of_tree: int32")
+ .Doc(R"doc(
+ Output the tree changes needed to resolve fertile nodes.
+
+ Previous Ops have already decided which fertile nodes want to stop being
+ fertile and what their best candidate split should be and have passed that
+ information to this Op in `finished_nodes` and `best_splits`. This Op
+ merely checks that there is still space in tree to add new nodes, and if
+ so, writes out the sparse updates needed for the fertile nodes to be
+ resolved to the tree, threshold and depth tensors.
+
+ end_of_tree: `end_of_tree[0]` is the number of allocated nodes, or
+ equivalently the index of the first free node in the tree tensor.
+ tree_depths: `tree_depths[i]` is the depth in the tree of node i.
+ node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
+ fertile node i, or -1 if node i isn't fertile.
+ finished_nodes:= A 1-d int32 tensor containing the indices of finished nodes.
+ best_splits: `best_splits[i]` is the index of the best split for
+ `finished_nodes[i]`.
+ candidate_split_features: `candidate_split_features[a][s]` is the feature
+ being considered for split s of the fertile node associated with
+ accumulator slot a.
+ candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the
+ threshold value being considered for split s of the fertile node associated
+ with accumulator slot a.
+ nodes_to_update:= A 1-d int32 tensor containing the node indices that need
+ updating.
+ tree_updates: The updates to apply to the 2-d tree tensor. Intended to be
+ used with `tf.scatter_update(tree, nodes_to_update, tree_updates)`.
+ threshold_updates: The updates to apply to the 1-d thresholds tensor.
+ Intended to be used with
+ `tf.scatter_update(thresholds, nodes_to_update, threshold_updates)`.
+ depth_updates: The updates to apply to the 1-d depths tensor. Intended to
+ be used with `tf.scatter_update(depths, nodes_to_update, depth_updates)`.
+ new_end_of_tree: `new_end_of_tree[0]` is the new size of the tree.
+)doc");
+
+// GrowTree converts finished fertile leaves into internal nodes.  For every
+// finished node that can be grown it emits three sparse updates: one for the
+// node itself (which now points at its children and records the chosen
+// split) and one for each of its two new leaf children.  The kernel does not
+// modify the tree; it only outputs the updates and the new end of the tree.
+class GrowTree : public OpKernel {
+ public:
+  explicit GrowTree(OpKernelConstruction* context) : OpKernel(context) {}
+
+  // Inputs:  end_of_tree, tree_depths, node_to_accumulator, finished_nodes,
+  //          best_splits, candidate_split_features,
+  //          candidate_split_thresholds.
+  // Outputs: nodes_to_update, tree_updates, threshold_updates, depth_updates,
+  //          new_end_of_tree.
+  // Fails with InvalidArgument if an input has the wrong rank or size, or if
+  // an index stored in one of the inputs is out of range.
+  void Compute(OpKernelContext* context) override {
+    const Tensor& end_of_tree = context->input(0);
+    const Tensor& tree_depths = context->input(1);
+    const Tensor& node_to_accumulator = context->input(2);
+    const Tensor& finished = context->input(3);
+    const Tensor& best_splits = context->input(4);
+    const Tensor& candidate_split_features = context->input(5);
+    const Tensor& candidate_split_thresholds = context->input(6);
+
+    OP_REQUIRES(context, end_of_tree.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "end_of_tree should be one-dimensional"));
+    OP_REQUIRES(context, tree_depths.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "tree_depths should be one-dimensional"));
+    OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "node_to_accumulator should be one-dimensional"));
+    OP_REQUIRES(context, finished.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "finished should be one-dimensional"));
+    OP_REQUIRES(context, best_splits.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "best_splits should be one-dimensional"));
+    OP_REQUIRES(context, candidate_split_features.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "candidate_split_features should be two-dimensional"));
+    OP_REQUIRES(context, candidate_split_thresholds.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "candidate_split_thresholds should be two-dimensional"));
+
+    // end_of_tree[0] is read below, so the tensor must hold an element.
+    OP_REQUIRES(context, end_of_tree.shape().dim_size(0) > 0,
+                errors::InvalidArgument(
+                    "end_of_tree should not be empty"));
+    OP_REQUIRES(
+        context,
+        finished.shape().dim_size(0) ==
+        best_splits.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of finished nodes should be the same in finished and "
+            "best_splits."));
+    OP_REQUIRES(
+        context,
+        tree_depths.shape().dim_size(0) ==
+        node_to_accumulator.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of nodes should be the same in tree_depths and "
+            "node_to_accumulator."));
+    OP_REQUIRES(
+        context,
+        candidate_split_features.shape().dim_size(0) ==
+        candidate_split_thresholds.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of accumulators should be the same in "
+            "candidate_split_features and candidate_split_thresholds."));
+    OP_REQUIRES(
+        context,
+        candidate_split_features.shape().dim_size(1) ==
+        candidate_split_thresholds.shape().dim_size(1),
+        errors::InvalidArgument(
+            "Number of splits should be the same in "
+            "candidate_split_features and candidate_split_thresholds."));
+
+    const auto depths = tree_depths.unaligned_flat<int32>();
+    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+    const auto finished_vec = finished.unaligned_flat<int32>();
+    const auto best_vec = best_splits.unaligned_flat<int32>();
+    const auto split_features = candidate_split_features.tensor<int32, 2>();
+    const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
+
+    const int32 num_finished = finished.shape().dim_size(0);
+    const int32 num_nodes = node_to_accumulator.shape().dim_size(0);
+    const int32 num_accumulators =
+        candidate_split_features.shape().dim_size(0);
+    const int32 num_splits = candidate_split_features.shape().dim_size(1);
+
+    int32 current_end_of_tree = end_of_tree.unaligned_flat<int32>()(0);
+    // A value outside [0, num_nodes] would make the free-space computation
+    // below negative and lead to a negative output dimension.
+    OP_REQUIRES(
+        context,
+        current_end_of_tree >= 0 && current_end_of_tree <= num_nodes,
+        errors::InvalidArgument(
+            "end_of_tree should be between 0 and the number of nodes."));
+
+    // Converting a leaf node into an internal node requires space for its
+    // two children.
+    const int32 remaining_node_space = (num_nodes - current_end_of_tree) / 2;
+    const int32 max_conversions = std::min(num_finished, remaining_node_space);
+
+    // First pass: validate the entries we are going to use and count how
+    // many finished nodes can really be converted.  Sizing the outputs from
+    // this count (instead of from max_conversions) guarantees that every
+    // output row is written, even when some finished nodes are skipped.
+    int32 num_conversions = 0;
+    for (int32 i = 0; i < num_finished && num_conversions < max_conversions;
+         i++) {
+      const int32 node = finished_vec(i);
+      OP_REQUIRES(context, node >= 0 && node < num_nodes,
+                  errors::InvalidArgument(
+                      "finished contains a node index that is out of range."));
+      const int32 accumulator = node_map(node);
+      if (accumulator < 0) {
+        LOG(ERROR) << "Finished node doesn't have an accumulator.";
+        continue;
+      }
+      OP_REQUIRES(context, accumulator < num_accumulators,
+                  errors::InvalidArgument(
+                      "node_to_accumulator contains an accumulator that is "
+                      "out of range."));
+      const int32 best = best_vec(i);
+      OP_REQUIRES(context, best >= 0 && best < num_splits,
+                  errors::InvalidArgument(
+                      "best_splits contains a split index that is out of "
+                      "range."));
+      ++num_conversions;
+    }
+
+    // Each conversion touches three nodes: the transitioning node and its
+    // two new children.
+    const int32 num_updates = 3 * num_conversions;
+
+    Tensor* nodes_to_update_tensor = nullptr;
+    TensorShape nodes_to_update_shape;
+    nodes_to_update_shape.AddDim(num_updates);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, nodes_to_update_shape,
+                                            &nodes_to_update_tensor));
+    auto nodes_to_update_flat = nodes_to_update_tensor->tensor<int32, 1>();
+
+    Tensor* tree_updates_tensor = nullptr;
+    TensorShape tree_updates_shape;
+    tree_updates_shape.AddDim(num_updates);
+    tree_updates_shape.AddDim(2);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, tree_updates_shape,
+                                            &tree_updates_tensor));
+    auto tree_updates_flat = tree_updates_tensor->tensor<int32, 2>();
+
+    Tensor* threshold_updates_tensor = nullptr;
+    TensorShape threshold_updates_shape;
+    threshold_updates_shape.AddDim(num_updates);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, threshold_updates_shape,
+                                            &threshold_updates_tensor));
+    auto threshold_updates_flat = threshold_updates_tensor->tensor<float, 1>();
+
+    Tensor* depth_updates_tensor = nullptr;
+    TensorShape depth_updates_shape;
+    depth_updates_shape.AddDim(num_updates);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(3, depth_updates_shape,
+                                            &depth_updates_tensor));
+    auto depth_updates_flat = depth_updates_tensor->tensor<int32, 1>();
+
+    // Writes one row of all four update outputs.
+    const auto write_update = [&](int32 slot, int32 node, int32 children,
+                                  int32 feature, float threshold,
+                                  int32 depth) {
+      nodes_to_update_flat(slot) = node;
+      tree_updates_flat(slot, CHILDREN_INDEX) = children;
+      tree_updates_flat(slot, FEATURE_INDEX) = feature;
+      threshold_updates_flat(slot) = threshold;
+      depth_updates_flat(slot) = depth;
+    };
+
+    // Second pass: emit the updates for exactly the nodes counted above.
+    // All indices used here were range-checked by the first pass.
+    int32 output_slot = 0;
+    for (int32 i = 0; i < num_finished && output_slot < num_updates; i++) {
+      const int32 node = finished_vec(i);
+      const int32 accumulator = node_map(node);
+      if (accumulator < 0) {
+        continue;  // Already reported by the first pass.
+      }
+      const int32 best = best_vec(i);
+      const int32 left = current_end_of_tree;
+      const int32 child_depth = depths(node) + 1;
+
+      // The finished node becomes an internal node using its best split.
+      write_update(output_slot++, node, left,
+                   split_features(accumulator, best),
+                   split_thresholds(accumulator, best), depths(node));
+      // Its children are fresh leaves stored next to each other.
+      write_update(output_slot++, left, LEAF_NODE, -1, 0.0f, child_depth);
+      write_update(output_slot++, left + 1, LEAF_NODE, -1, 0.0f, child_depth);
+
+      current_end_of_tree += 2;
+    }
+
+    Tensor* new_end_of_tree_tensor = nullptr;
+    TensorShape new_end_of_tree_shape;
+    new_end_of_tree_shape.AddDim(1);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(4, new_end_of_tree_shape,
+                                            &new_end_of_tree_tensor));
+    auto new_end_of_tree_flat = new_end_of_tree_tensor->tensor<int32, 1>();
+    new_end_of_tree_flat(0) = current_end_of_tree;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("GrowTree").Device(DEVICE_CPU), GrowTree);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
new file mode 100644
index 0000000..452a683
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
@@ -0,0 +1,241 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// SampleInputs initializes candidate splits/threshold values randomly
+// from incoming data for not-yet-initialized fertile nodes.
+#include <ctime>
+#include <unordered_map>
+#include <set>
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+using tensorforest::IsAllInitialized;
+
+
+// TODO(gilberth): Reinitialize candidate splits that were finished last round.
+REGISTER_OP("SampleInputs")
+ .Attr("split_initializations_per_input: int32")
+ .Attr("split_sampling_random_seed: int32")
+ .Input("input_data: float")
+ .Input("node_to_accumulator: int32")
+ .Input("leaves: int32")
+ .Input("candidate_split_features: int32")
+ .Input("candidate_split_thresholds: float")
+ .Output("accumulators_to_update: int32")
+ .Output("new_split_feature_rows: int32")
+ .Output("new_split_threshold_rows: float")
+ .Doc(R"doc(
+ Initializes candidate splits for newly fertile nodes.
+
+ In an extremely random forest, we don't consider all possible threshold
+ values for a candidate split feature, but rather only a sampling of them.
+ This Op takes those samples from the training data in `input_data`. The
+ feature and threshold samples are stored in tensors that are indexed by
+ accumulator slot, so for each input, we must first look up which leaf
+ it ended up in (using `leaves`) and then which accumulator slot if any
+ that leaf maps to (using `node_to_accumulator`).
+
+ The attribute `split_initializations_per_input` controls how many splits
+ a single training example can initialize, and the attribute
+ `split_sampling_random_seed` sets the random number generator's seed
+ (a value of 0 means use the current time as the seed).
+
+ input_data: The features for the current batch of training data.
+ `input_data[i][j]` is the j-th feature of the i-th input.
+ node_to_accumulator: For a fertile node i, node_to_accumulator[i] is the
+ associated accumulator slot. For non-fertile nodes, it is -1.
+ leaves: `leaves[i]` is the leaf that the i-th input landed in, as
+ calculated by CountExtremelyRandomStats.
+ candidate_split_features: The current features for the candidate splits;
+ `candidate_split_features[a][s]` is the index of the feature being
+ considered by split s in accumulator slot a.
+ candidate_split_thresholds: The current thresholds for the candidate splits;
+ `candidate_split_thresholds[a][s]` is the threshold value being
+ considered by split s in accumulator slot a.
+ accumulators_to_update: A list of the accumulators to change in the
+ candidate_split_features and candidate_split_thresholds tensors.
+ new_split_feature_rows: The new values for the candidate_split_features
+ tensor. Intended to be used with
+ `tf.scatter_update(candidate_split_features,
+ accumulators_to_update,
+ new_split_feature_rows)`
+ new_split_threshold_rows: The new values for the candidate_split_thresholds
+ tensor. Intended to be used with
+ `tf.scatter_update(candidate_split_thresholds,
+ accumulators_to_update,
+ new_split_threshold_rows)`
+)doc");
+
+// SampleInputs fills in not-yet-initialized candidate splits of fertile
+// nodes.  Each uninitialized split gets a uniformly chosen feature index and,
+// as threshold, the value of that feature in one of the inputs that landed in
+// the node.  The kernel outputs whole replacement rows for the accumulators
+// it touched; it does not modify its inputs.
+//
+// NOTE(review): rng_ is advanced from Compute() without any synchronization;
+// confirm that this kernel is never executed concurrently.
+class SampleInputs : public OpKernel {
+ public:
+  explicit SampleInputs(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "split_initializations_per_input", &split_initializations_per_input_));
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "split_sampling_random_seed", &split_sampling_random_seed_));
+    // Set up the random number generator.  A seed of 0 means "seed from the
+    // current time".
+    const uint64 seed =
+        split_sampling_random_seed_ == 0
+            ? static_cast<uint64>(std::time(nullptr))
+            : static_cast<uint64>(split_sampling_random_seed_);
+    single_rand_.reset(new random::PhiloxRandom(seed));
+    rng_.reset(new random::SimplePhilox(single_rand_.get()));
+  }
+
+  // Inputs:  input_data, node_to_accumulator, leaves,
+  //          candidate_split_features, candidate_split_thresholds.
+  // Outputs: accumulators_to_update, new_split_feature_rows,
+  //          new_split_threshold_rows.
+  // Fails with InvalidArgument if an input has the wrong rank or size, or if
+  // an index stored in one of the inputs is out of range.
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_data = context->input(0);
+    const Tensor& node_to_accumulator = context->input(1);
+    const Tensor& leaves = context->input(2);
+    const Tensor& split_features = context->input(3);
+    const Tensor& split_thresholds = context->input(4);
+
+    OP_REQUIRES(context, input_data.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "input_data should be two-dimensional"));
+    OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "node_to_accumulator should be one-dimensional"));
+    OP_REQUIRES(context, leaves.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "leaves should be one-dimensional"));
+    OP_REQUIRES(context, split_features.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "split_features should be two-dimensional"));
+    OP_REQUIRES(context, split_thresholds.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "split_thresholds should be two-dimensional"));
+
+    OP_REQUIRES(
+        context,
+        split_features.shape() == split_thresholds.shape(),
+        errors::InvalidArgument(
+            "split_features and split_thresholds should be the same shape."));
+    // leaves[i] describes input_data[i], so both need the same number of
+    // rows; otherwise the threshold lookup below reads past input_data.
+    OP_REQUIRES(
+        context,
+        input_data.shape().dim_size(0) == leaves.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of inputs should be the same in input_data and leaves."));
+
+    const auto inputs = input_data.tensor<float, 2>();
+    const auto leaves_vec = leaves.unaligned_flat<int32>();
+    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+    const auto features = split_features.tensor<int32, 2>();
+    const auto thresholds = split_thresholds.tensor<float, 2>();
+
+    const int32 num_data = leaves.shape().dim_size(0);
+    const int32 num_nodes = node_to_accumulator.shape().dim_size(0);
+    const int32 num_accumulators = split_features.shape().dim_size(0);
+    const int32 num_splits = split_features.shape().dim_size(1);
+    const int32 num_features = input_data.shape().dim_size(1);
+
+    // Maps each accumulator that still has uninitialized splits to the
+    // indices of the inputs that landed in its node.
+    std::unordered_map<int32, std::set<int32>> accumulator_to_inputs;
+
+    // The first pass just calculates num_output_accumulators.
+    for (int i = 0; i < num_data; i++) {
+      const int32 leaf = leaves_vec(i);
+      OP_REQUIRES(context, leaf >= 0 && leaf < num_nodes,
+                  errors::InvalidArgument(
+                      "leaves contains a node index that is out of range."));
+      const int32 accumulator = node_map(leaf);
+      if (accumulator < 0) {
+        continue;  // Non-fertile node.
+      }
+      OP_REQUIRES(context, accumulator < num_accumulators,
+                  errors::InvalidArgument(
+                      "node_to_accumulator contains an accumulator that is "
+                      "out of range."));
+      // Skip fertile nodes that are already fully initialized.
+      if (!IsAllInitialized(
+              split_features.Slice(accumulator, accumulator + 1))) {
+        accumulator_to_inputs[accumulator].insert(i);
+      }
+    }
+
+    // Sampling a feature needs at least one feature to choose from.
+    OP_REQUIRES(context, accumulator_to_inputs.empty() || num_features > 0,
+                errors::InvalidArgument(
+                    "input_data should have at least one feature."));
+
+    // Now we can allocate the outputs.
+    int num_output_accumulators = accumulator_to_inputs.size();
+    VLOG(1) << "num output accumulators = " << num_output_accumulators;
+    Tensor* accumulators_tensor = nullptr;
+    TensorShape accumulators_shape;
+    accumulators_shape.AddDim(num_output_accumulators);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, accumulators_shape,
+                                            &accumulators_tensor));
+    auto accumulators_flat = accumulators_tensor->tensor<int32, 1>();
+
+    Tensor* new_split_feature_rows_tensor = nullptr;
+    TensorShape new_split_feature_rows_shape;
+    new_split_feature_rows_shape.AddDim(num_output_accumulators);
+    new_split_feature_rows_shape.AddDim(num_splits);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, new_split_feature_rows_shape,
+                                            &new_split_feature_rows_tensor));
+    auto new_split_feature_rows_flat =
+        new_split_feature_rows_tensor->tensor<int32, 2>();
+
+    Tensor* new_split_threshold_rows_tensor = nullptr;
+    TensorShape new_split_threshold_rows_shape;
+    new_split_threshold_rows_shape.AddDim(num_output_accumulators);
+    new_split_threshold_rows_shape.AddDim(num_splits);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, new_split_threshold_rows_shape,
+                                            &new_split_threshold_rows_tensor));
+    auto new_split_threshold_rows_flat =
+        new_split_threshold_rows_tensor->tensor<float, 2>();
+
+    // The second pass fills out the outputs.
+    int output_slot = 0;
+    for (const auto& active : accumulator_to_inputs) {
+      const int32 accumulator = active.first;
+      // Bind by reference: copying the set for every accumulator is wasted
+      // work.
+      const std::set<int32>& inputs_for_accumulator = active.second;
+      VLOG(1) << "Accumulator " << accumulator
+              << " gets new output slot " << output_slot;
+      accumulators_flat(output_slot) = accumulator;
+
+      // scatter_update updates entire rows, so we first copy the existing
+      // rows into the output tensors, and then write over the values we
+      // want to change.
+      for (int split = 0; split < num_splits; split++) {
+        new_split_feature_rows_flat(output_slot, split) =
+            features(accumulator, split);
+        new_split_threshold_rows_flat(output_slot, split) =
+            thresholds(accumulator, split);
+      }
+
+      // Splits are initialized in increasing order and an initialized split
+      // always holds a non-negative feature, so everything before `split` is
+      // already initialized.  Keeping the cursor across inputs therefore
+      // visits the same splits as rescanning the row from 0 for every input.
+      int split = 0;
+      for (const int32 i : inputs_for_accumulator) {
+        if (split >= num_splits) {
+          break;  // The whole row is initialized.
+        }
+        VLOG(2) << "Looking at data # " << i;
+
+        int32 num_inits = split_initializations_per_input_;
+        for (; split < num_splits && num_inits > 0; split++) {
+          if (new_split_feature_rows_flat(output_slot, split) < 0) {
+            VLOG(1) << "Over-writing @ " << output_slot << "," << split;
+            const int32 index = rng_->Uniform(num_features);
+            new_split_feature_rows_flat(output_slot, split) = index;
+            new_split_threshold_rows_flat(output_slot, split) =
+                inputs(i, index);
+            --num_inits;
+          }
+        }
+      }
+      ++output_slot;
+    }
+  }
+
+ private:
+  // Maximum number of splits a single input may initialize.
+  int32 split_initializations_per_input_;
+  // Seed for the generator; 0 selects a time-based seed.
+  int32 split_sampling_random_seed_;
+  // single_rand_ owns the generator state that rng_ wraps, so it must
+  // outlive rng_.
+  std::unique_ptr<random::PhiloxRandom> single_rand_;
+  std::unique_ptr<random::SimplePhilox> rng_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("SampleInputs").Device(DEVICE_CPU), SampleInputs);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/scatter_add_ndim_op.cc b/tensorflow/contrib/tensor_forest/core/ops/scatter_add_ndim_op.cc
new file mode 100644
index 0000000..a65a2c0
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/scatter_add_ndim_op.cc
@@ -0,0 +1,103 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// ScatterAddNdim implements a scatter_add that can operate on sparse
+// updates without being limited to the first dimension for indices.
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+
+
+namespace tensorflow {
+
+REGISTER_OP("ScatterAddNdim")
+ .Input("input: Ref(float)")
+ .Input("indices: int32")
+ .Input("deltas: float")
+
+ .Doc(R"doc(
+ Add elements in deltas to mutable input according to indices.
+
+ input: An N-dimensional float tensor to mutate.
+ indices:= A 2-D int32 tensor. The size of dimension 0 is the number of
+ deltas, the size of dimension 1 is the rank of the input. `indices[i]`
+ gives the coordinates of input that `deltas[i]` should add to.
+ deltas: `deltas[i]` is the value to add to input at index indices[i][:].
+)doc");
+
+
+// ScatterAddNdim adds deltas[i] to the element of the mutable float input
+// whose coordinates are given by row i of indices.
+//
+// NOTE(review): the ref input is fetched with lock_held=false and no mutex
+// is taken here; confirm that concurrent updates of the same variable are
+// not expected.
+class ScatterAddNdim : public OpKernel {
+ public:
+  explicit ScatterAddNdim(OpKernelConstruction* context) : OpKernel(context) {}
+
+  // Inputs: input (mutable), indices, deltas.  Produces no outputs.
+  // Fails with InvalidArgument if the shapes are inconsistent or if any
+  // coordinate lies outside input; in that case input is left untouched.
+  void Compute(OpKernelContext* context) override {
+    Tensor input_tensor = context->mutable_input(0, false);
+    const Tensor& indices_tensor = context->input(1);
+    const Tensor& deltas_tensor = context->input(2);
+
+    OP_REQUIRES(context, deltas_tensor.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "deltas should be one-dimensional"));
+    // An indices tensor without rows means there is nothing to do.  The rank
+    // is tested first because dim_size(0) is only valid for rank >= 1.
+    if (indices_tensor.shape().dims() > 0 &&
+        indices_tensor.shape().dim_size(0) == 0) {
+      return;
+    }
+    OP_REQUIRES(context, indices_tensor.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "indices should be two-dimensional"));
+    OP_REQUIRES(
+        context,
+        indices_tensor.shape().dim_size(1) == input_tensor.shape().dims(),
+        errors::InvalidArgument(
+            "Number of indices dimensions should be the same as input "
+            "rank."));
+    OP_REQUIRES(
+        context,
+        indices_tensor.shape().dim_size(0) ==
+        deltas_tensor.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of updates should be same as number of indices."));
+
+    auto input = input_tensor.flat<float>();
+
+    const auto indices = indices_tensor.tensor<int32, 2>();
+    const auto deltas = deltas_tensor.unaligned_flat<float>();
+
+    const int32 num_dims = indices_tensor.shape().dim_size(1);
+    const int32 num_updates = indices_tensor.shape().dim_size(0);
+
+    // Row-major strides of input: the stride of a dimension is the product
+    // of the sizes of all later dimensions.  Building them from the back
+    // avoids dividing by a dimension size, which may be zero.
+    std::vector<int64> strides(num_dims, 1);
+    for (int32 j = num_dims - 2; j >= 0; j--) {
+      strides[j] = strides[j + 1] * input_tensor.shape().dim_size(j + 1);
+    }
+
+    // Validate every coordinate before touching input, so that a bad index
+    // neither writes outside the buffer nor leaves a partial update behind.
+    for (int32 i = 0; i < num_updates; i++) {
+      for (int32 j = 0; j < num_dims; j++) {
+        const int32 coordinate = indices(i, j);
+        OP_REQUIRES(
+            context,
+            coordinate >= 0 &&
+            coordinate < input_tensor.shape().dim_size(j),
+            errors::InvalidArgument(
+                "indices contains a coordinate that is out of range for "
+                "input."));
+      }
+    }
+
+    // Perform updates.
+    for (int32 i = 0; i < num_updates; i++) {
+      int64 offset = 0;
+      for (int32 j = 0; j < num_dims; j++) {
+        offset += indices(i, j) * strides[j];
+      }
+      input(offset) += deltas(i);
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ScatterAddNdim").Device(DEVICE_CPU),
+ ScatterAddNdim);
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
new file mode 100644
index 0000000..3e84534
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
@@ -0,0 +1,170 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// TreePredictions returns the per-class probabilities for each input by
+// evaluating the given tree.
+#include <algorithm>
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+using tensorforest::CHILDREN_INDEX;
+using tensorforest::FEATURE_INDEX;
+using tensorforest::LEAF_NODE;
+using tensorforest::FREE_NODE;
+
+using tensorforest::DecideNode;
+using tensorforest::Sum;
+
+REGISTER_OP("TreePredictions")
+ .Attr("valid_leaf_threshold: float")
+ .Input("input_data: float")
+ .Input("tree: int32")
+ .Input("tree_thresholds: float")
+ .Input("node_per_class_weights: float")
+
+ .Output("predictions: float")
+ .Doc(R"doc(
+ Returns the per-class probabilities for each input.
+
+ input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
+ gives the j-th feature of the i-th input.
+ tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
+ of the i-th node, `tree[i][0] + 1` gives the index of the right child of
+ the i-th node, and `tree[i][1]` gives the index of the feature used to
+ split the i-th node.
+ tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
+ node.
+ node_per_class_weights: `node_per_class_weights[n][c]` records how many
+ training examples have class c and have ended up in node n.
+ predictions: `predictions[i][j]` is the probability that input i is class j.
+ valid_leaf_threshold: Minimum number of samples that have arrived to a leaf
+ to be considered a valid leaf, otherwise use the parent.
+)doc");
+
+
+// TreePredictions walks every input from the root of the tree down to a
+// leaf and outputs the leaf's per-class weights normalized to sum to one.
+// Leaves that have seen fewer than valid_leaf_threshold samples borrow
+// weight from their parent.
+class TreePredictions : public OpKernel {
+ public:
+  explicit TreePredictions(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "valid_leaf_threshold", &valid_leaf_threshold_));
+  }
+
+  // Inputs:  input_data, tree, tree_thresholds, node_per_class_weights.
+  // Outputs: predictions, one row per input and one column per class.
+  // Fails with InvalidArgument if an input has the wrong rank or size, or if
+  // walking the tree leaves the tree, reaches a free node or does not end.
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_data = context->input(0);
+
+    const Tensor& tree_tensor = context->input(1);
+    const Tensor& tree_thresholds = context->input(2);
+    const Tensor& node_per_class_weights = context->input(3);
+
+    OP_REQUIRES(context, tree_tensor.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "tree should be two-dimensional"));
+    OP_REQUIRES(context, tree_thresholds.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "tree_thresholds should be one-dimensional"));
+    OP_REQUIRES(context, node_per_class_weights.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "node_pcw should be two-dimensional"));
+
+    // An input without rows is accepted whatever its rank.  The rank is
+    // tested first because dim_size(0) is only valid for rank >= 1.
+    if (input_data.shape().dims() == 0 ||
+        input_data.shape().dim_size(0) > 0) {
+      OP_REQUIRES(context, input_data.shape().dims() == 2,
+                  errors::InvalidArgument(
+                      "input_data should be two-dimensional"));
+    }
+    OP_REQUIRES(
+        context,
+        tree_tensor.shape().dim_size(0) ==
+        tree_thresholds.shape().dim_size(0) &&
+        tree_tensor.shape().dim_size(0) ==
+        node_per_class_weights.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of nodes should be the same in "
+            "tree, tree_thresholds and node_pcw."));
+
+    const int32 num_classes = node_per_class_weights.shape().dim_size(1);
+    const int32 num_data = input_data.shape().dim_size(0);
+    const int32 num_nodes = tree_tensor.shape().dim_size(0);
+
+    Tensor* output_predictions = nullptr;
+    TensorShape output_shape;
+    output_shape.AddDim(num_data);
+    output_shape.AddDim(num_classes);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, output_shape,
+                                            &output_predictions));
+    auto out = output_predictions->tensor<float, 2>();
+
+    const auto node_pcw = node_per_class_weights.tensor<float, 2>();
+    const auto tree = tree_tensor.tensor<int32, 2>();
+    const auto thresholds = tree_thresholds.unaligned_flat<float>();
+
+    for (int i = 0; i < num_data; i++) {
+      const Tensor point = input_data.Slice(i, i+1);
+      int node_index = 0;
+      int parent = -1;
+      // A walk through a well-formed tree visits each node at most once, so
+      // it needs at most num_nodes steps.  Bounding the walk turns a cyclic
+      // or otherwise corrupt tree into an error instead of an endless loop
+      // or an out-of-bounds read.
+      for (int32 steps = 0; ; steps++) {
+        OP_REQUIRES(
+            context,
+            steps < num_nodes && node_index >= 0 && node_index < num_nodes,
+            errors::InvalidArgument(
+                "tree is malformed: the walk left the tree or did not "
+                "terminate."));
+        const int32 left_child = tree(node_index, CHILDREN_INDEX);
+        // Report the problem through the op status; returning normally
+        // would hand uninitialized predictions to the caller.
+        OP_REQUIRES(context, left_child != FREE_NODE,
+                    errors::InvalidArgument(
+                        "Reached a free node, not good."));
+        if (left_child == LEAF_NODE) {
+          float sum = Sum<float>(node_per_class_weights.Slice(
+              node_index, node_index + 1));
+          float parent_weight = 0.0;
+          if (sum < valid_leaf_threshold_ && parent >= 0) {
+            VLOG(1) << "not enough samples at leaf, including parent counts."
+                    << "child sum = " << sum;
+            float parent_sum = Sum<float>(node_per_class_weights.Slice(
+                parent, parent + 1));
+            // Weight the parent's counts just enough so that the new sum is
+            // valid_leaf_threshold_, but never give any counts a weight of
+            // more than 1.  A parent without counts has nothing to add.
+            if (parent_sum > 0.0f) {
+              parent_weight = std::min(
+                  1.0f, (valid_leaf_threshold_ - sum) / parent_sum);
+              sum += parent_weight * parent_sum;
+            }
+            VLOG(1) << "Sum w/ parent included = " << sum;
+          }
+          for (int c = 0; c < num_classes; c++) {
+            float w = node_pcw(node_index, c);
+            if (parent_weight > 0.0) {
+              w += parent_weight * node_pcw(parent, c);
+            }
+            // Without any counts there is no evidence for any class, so fall
+            // back to a uniform distribution rather than dividing by zero.
+            out(i, c) = sum > 0.0f ? w / sum : 1.0f / num_classes;
+          }
+          break;
+        }
+        // NOTE(review): the feature index stored in the tree is passed to
+        // DecideNode unchecked; confirm it is always a valid column of
+        // input_data for internal nodes.
+        parent = node_index;
+        node_index = left_child +
+            DecideNode(point, tree(node_index, FEATURE_INDEX),
+                       thresholds(node_index));
+      }
+    }
+
+    VLOG(1) << "tree: " << tree;
+    VLOG(1) << "output: " << out;
+  }
+
+ private:
+  // Minimum total weight a leaf needs before it is trusted on its own.
+  float valid_leaf_threshold_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("TreePredictions").Device(DEVICE_CPU),
+ TreePredictions);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
new file mode 100644
index 0000000..df50a3d
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
@@ -0,0 +1,67 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include <limits>
+
+namespace tensorflow {
+namespace tensorforest {
+
+using tensorflow::Tensor;
+
+// Returns the index of the candidate split with the lowest combined score
+// (weighted Gini impurity of the split counts plus that of the remaining
+// counts) for one accumulator slot, or -1 if no split was selected (for
+// example when there are no candidate splits).
+//
+// total_counts: per-accumulator class counts; only the slice at
+//     `accumulator` is read.
+// split_counts: per-accumulator, per-split class counts. Dimension 1 is
+//     taken as the number of splits and dimension 2 as the number of classes.
+// accumulator: index of the accumulator slot to evaluate.
+int32 BestFeature(const Tensor& total_counts, const Tensor& split_counts,
+                  int32 accumulator) {
+  int32 best_feature_index = -1;
+  // We choose the split with the lowest score. Start from the largest finite
+  // float so that any real score can replace it; an integer constant such as
+  // kint64max converts to a much smaller float and would mask larger scores.
+  float best_score = std::numeric_limits<float>::max();
+  const int32 num_splits = split_counts.shape().dim_size(1);
+  const int32 num_classes = split_counts.shape().dim_size(2);
+  // Ideally, Eigen::Tensor::chip would be best to use here but it results
+  // in seg faults, so we have to go with flat views of these tensors.  However,
+  // it is still pretty efficient because we put off evaluation until the
+  // score is actually returned.
+  const auto tc = total_counts.Slice(
+      accumulator, accumulator + 1).unaligned_flat<float>();
+  const auto splits = split_counts.Slice(
+      accumulator, accumulator + 1).unaligned_flat<float>();
+  // Repeat the totals once per split so they line up with the flat split
+  // counts; the difference is the count that did not go into each split.
+  Eigen::array<int, 1> bcast({num_splits});
+  const auto rights = tc.broadcast(bcast) - splits;
+
+  for (int i = 0; i < num_splits; i++) {
+    // Select the num_classes entries belonging to split i in the flat views.
+    Eigen::array<int, 1> offsets = {i * num_classes};
+    Eigen::array<int, 1> extents = {num_classes};
+    float score = WeightedGiniImpurity(splits.slice(offsets, extents)) +
+        WeightedGiniImpurity(rights.slice(offsets, extents));
+
+    if (score < best_score) {
+      best_score = score;
+      best_feature_index = i;
+    }
+  }
+  return best_feature_index;
+}
+
+// Decides which way `point` goes at a split: returns true when the value of
+// the selected feature is strictly greater than `bias`, false otherwise.
+bool DecideNode(const Tensor& point, int32 feature, float bias) {
+  const float feature_value = point.unaligned_flat<float>()(feature);
+  return bias < feature_value;
+}
+
+// Returns true if the last entry of `features` (viewed as a flat int32
+// tensor) is non-negative, which is taken to mean every entry has been
+// initialized. An empty tensor has nothing left to initialize and is
+// reported as initialized.
+bool IsAllInitialized(const Tensor& features) {
+  const auto feature_vec = features.unaligned_flat<int32>();
+  if (feature_vec.size() == 0) {
+    // Guard against reading index -1 when there are no entries at all.
+    return true;
+  }
+  return feature_vec(feature_vec.size() - 1) >= 0;
+}
+
+
+} // namespace tensorforest
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
new file mode 100644
index 0000000..9b7553e
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
@@ -0,0 +1,89 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef LEARNING_LIB_TENSOR_FOREST_V2_TREE_UTILS_H_
+#define LEARNING_LIB_TENSOR_FOREST_V2_TREE_UTILS_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tensorforest {
+
+// Indexes in the tree representation's 2nd dimension for children and features.
+const int32 CHILDREN_INDEX = 0;
+const int32 FEATURE_INDEX = 1;
+
+// Used in the tree's children sub-tensor to indicate leaf and free nodes.
+const int32 LEAF_NODE = -1;
+const int32 FREE_NODE = -2;
+
+// Returns the sum of all elements of `counts`, read as a flat tensor of T.
+template<typename T>
+T Sum(tensorflow::Tensor counts) {
+  const auto flat_counts = counts.unaligned_flat<T>();
+  // Assigning the reduction to a rank-0 tensor evaluates it.
+  Eigen::Tensor<T, 0, Eigen::RowMajor> total = flat_counts.sum();
+  return total(0);
+}
+
+// Given an Eigen::Tensor type, calculate the Gini impurity, which we use
+// to determine the best split (lowest) and which nodes to allocate first
+// (highest).
+//
+// counts: an Eigen tensor (or tensor expression) of per-class counts. One is
+//     added to every count before the score is computed.
+// Returns the score as a float; returning an integer type here would
+// truncate the fractional part and make distinct splits compare equal.
+template<typename T>
+float WeightedGiniImpurity(const T& counts) {
+  // Our split score is the Gini impurity times the number of examples
+  // seen by the leaf.  If c(i) denotes the i-th class count and c = sum_i c(i)
+  // then
+  // score = c * (1 - sum_i ( c(i) / c )^2 )
+  //       = c - sum_i c(i)^2 / c
+  const auto smoothed = counts + counts.constant(1.0f);
+  const auto sum = smoothed.sum();
+  const auto sum2 = smoothed.square().sum();
+  // Materialize the expression into a rank-0 tensor to get the scalar.
+  Eigen::Tensor<float, 0, Eigen::RowMajor> ret = sum - (sum2 / sum);
+  return ret(0);
+}
+
+// Returns the best split to use based on the (lowest) Gini impurity.
+// Takes in the whole total and per-split count tensors because using
+// Tensor::Slice returns a tensor of the same dimensionality, which makes
+// things a little awkward.
+//
+// `accumulator` selects the slot whose counts are evaluated in both tensors.
+// The return value is an index into the candidate splits of that slot, or -1
+// if no split was selected.
+// TODO(gilberth): Currently test_util.BestFeatureToSplit doesn't work with
+// this code because the shapes of the incoming tensors are different than
+// in v1.  Try to make it work for both versions?
+int32 BestFeature(const tensorflow::Tensor& total_counts,
+                  const tensorflow::Tensor& split_counts,
+                  int32 accumulator);
+
+// Overwrites every element of `counts` (viewed as a flat tensor of T) with
+// `val`, which defaults to 0.
+template <typename T>
+void Initialize(tensorflow::Tensor counts, T val = 0) {
+  auto flat = counts.unaligned_flat<T>();
+  T* const first = flat.data();
+  T* const last = first + flat.size();
+  std::fill(first, last, val);
+}
+
+// Returns true if the point falls to the right (i.e., the selected feature
+// of the input point is greater than the bias threshold), and false if it
+// falls to the left.
+bool DecideNode(const tensorflow::Tensor& point, int32 feature, float bias);
+
+// Returns true if all the splits are initialized. Since they get initialized
+// in order, we can simply infer this from the last split.
+// This should only be called for a single allocator's candidate features
+// (i.e. candidate_split_features.Slice(accumulator, accumulator + 1) ).
+bool IsAllInitialized(const tensorflow::Tensor& features);
+
+} // namespace tensorforest
+} // namespace tensorflow
+
+#endif // LEARNING_LIB_TENSOR_FOREST_V2_TREE_UTILS_H_
diff --git a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
new file mode 100644
index 0000000..0bf7525
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
@@ -0,0 +1,407 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// UpdateFertileSlots manages accumulator slots. It assigns free or newly
+// finished accumulator slots to waiting non-fertile nodes and new leaves
+// according to their existing split scores (based on node pcws). It does not
+// allocate slots to leaves that are beyond max depth.
+#include <unordered_map>
+#include <set>
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/top_n.h"
+
+
+namespace tensorflow {
+
+using gtl::TopN;
+using tensorforest::Initialize;
+using tensorforest::WeightedGiniImpurity;
+
+
+REGISTER_OP("UpdateFertileSlots")
+ .Attr("max_depth: int32")
+ .Input("finished: int32")
+ .Input("non_fertile_leaves: int32")
+ .Input("non_fertile_leaf_scores: float")
+ .Input("end_of_tree: int32")
+ .Input("tree_depths: int32")
+ .Input("pcw_candidate_splits: float")
+ .Input("pcw_total_splits: float")
+ .Input("node_to_accumulator: int32")
+ .Output("node_map_updates: int32")
+ .Output("accumulators_cleared: int32")
+ .Output("accumulators_allocated: int32")
+ .Output("new_nonfertile_leaves: int32")
+ .Output("new_nonfertile_leaves_scores: float")
+ .Doc(R"doc(
+ Updates accumulator slots to reflect finished or newly fertile nodes.
+
+ Leaves at the depth of the attribute `max_depth` won't be made fertile
+ (i.e., won't be given an accumulator slot.)
+
+ finished:= A 1-d int32 tensor containing the indices of fertile nodes that
+ are ready to decide on a split.
+ non_fertile_leaves:= A 1-d int32 tensor containing the indices of all the
+ currently non-fertile leaves. If there are free accumulator slots after
+ deallocation, UpdateFertileSlots will consider these nodes (plus the ones
+ in new_leaves) and potentially turn some of them fertile.
+ non_fertile_leaf_scores: `non_fertile_leaf_scores[i]` is the splitting score
+ of the non-fertile leaf `non_fertile_leaves[i]`.
+ end_of_tree: The end of tree tensor from the previous training iteration, used
+ with the finished input to calculate a list of new leaf indices created by
+ GrowTree, which will be considered to become fertile if there are free
+ slots.
+ tree_depths: `tree_depths[i]` is the depth in the tree of node i.
+ pcw_candidate_splits: `pcw_candidate_splits[a][s][c]` records how many
+ training examples have class c and have ended up in the fertile node
+ associated with accumulator slot a and then taken the *left* branch of
+ candidate split s.
+ pcw_total_splits: `pcw_total_splits[a][c]` records how many training examples
+ have class c and have ended up in the fertile node associated with
+ accumulator slot a. Between that and `pcw_candidate_splits`, the number of
+ examples taking the right branch of a split can be reconstructed.
+ node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
+ fertile node i, or -1 if node i isn't fertile.
+ node_map_updates:= A 2-d int32 tensor describing the changes that need to
+ be applied to the node_to_accumulator map. Intended to be used with
+ `tf.scatter_update(node_to_accumulator,
+ node_map_updates[0],
+ node_map_updates[1])`.
+ accumulators_cleared:= A 1-d int32 tensor containing the indices of all
+ the accumulator slots that need to be cleared.
+ accumulators_allocated:= A 1-d int32 tensor containing the indices of all
+ the accumulator slots that need to be allocated.
+ new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the
+ leaves that are now non-fertile.
+ new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the
+ splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`.
+)doc");
+
+
+// Kernel for the UpdateFertileSlots op. It hands free or newly finished
+// accumulator slots to candidate leaves whose depth is below max_depth, in
+// the order produced by the score-ranked leaf heap, and emits the resulting
+// bookkeeping changes as five output tensors.
+class UpdateFertileSlots : public OpKernel {
+ public:
+  explicit UpdateFertileSlots(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "max_depth", &max_depth_));
+  }
+
+  // Validates the inputs, ranks candidate leaves, assigns accumulators to as
+  // many of them as there are available slots, and writes all outputs.
+  void Compute(OpKernelContext* context) override {
+    const Tensor& finished = context->input(0);
+
+    const Tensor& non_fertile_leaves = context->input(1);
+    const Tensor& non_fertile_leaf_scores = context->input(2);
+    const Tensor& end_of_tree = context->input(3);
+    const Tensor& tree_depths = context->input(4);
+
+    const Tensor& pcw_candidate_splits = context->input(5);
+    const Tensor& pcw_total_splits = context->input(6);
+    const Tensor& node_to_accumulator = context->input(7);
+
+    OP_REQUIRES(context, finished.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "finished should be one-dimensional"));
+    OP_REQUIRES(context, non_fertile_leaves.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "non_fertile_leaves should be one-dimensional"));
+    OP_REQUIRES(context, non_fertile_leaf_scores.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "non_fertile_leaves_scores should be one-dimensional"));
+    OP_REQUIRES(context, end_of_tree.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "end_of_tree should be one-dimensional"));
+    OP_REQUIRES(context, tree_depths.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "tree_depths should be one-dimensional"));
+    OP_REQUIRES(context, pcw_candidate_splits.shape().dims() == 3,
+                errors::InvalidArgument(
+                    "pcw_candidate_splits should be three-dimensional"));
+    OP_REQUIRES(context, pcw_total_splits.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "pcw_total_splits should be two-dimensional"));
+    OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "node_to_accumulator should be one-dimensional"));
+
+    OP_REQUIRES(
+        context,
+        pcw_candidate_splits.shape().dim_size(0) ==
+        pcw_total_splits.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of accumulators should be the same in pcw_candidate_splits "
+            "and pcw_total_splits."));
+    OP_REQUIRES(
+        context,
+        non_fertile_leaves.shape().dim_size(0) ==
+        non_fertile_leaf_scores.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of non fertile leaves should be the same in "
+            "non_fertile_leaves and non_fertile_leaf_scores."));
+    // end_of_tree(0) is read below, so it must exist.
+    OP_REQUIRES(context, end_of_tree.shape().dim_size(0) >= 1,
+                errors::InvalidArgument(
+                    "end_of_tree should contain at least one element"));
+
+    // Read finished accumulators into a set for quick lookup.
+    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+    const auto finished_vec = finished.unaligned_flat<int32>();
+    std::set<int32> finished_accumulators;
+    for (int32 i = 0; i < finished_vec.size(); ++i) {
+      const int32 finished_node = finished_vec(i);
+      // Reject indices that would read outside node_to_accumulator.
+      OP_REQUIRES(
+          context, finished_node >= 0 && finished_node < node_map.size(),
+          errors::InvalidArgument(
+              "finished contains a node index that is out of range for "
+              "node_to_accumulator"));
+      finished_accumulators.insert(node_map(finished_node));
+    }
+
+    // Construct leaf heap to sort leaves to allocate accumulators to.
+    const auto eot = end_of_tree.unaligned_flat<int32>();
+    const int32 num_nodes = tree_depths.shape().dim_size(0);
+    // New leaves are read from tree_depths starting at eot(0), so it has to
+    // lie inside (or exactly at the end of) the tree.
+    OP_REQUIRES(context, eot(0) >= 0 && eot(0) <= num_nodes,
+                errors::InvalidArgument(
+                    "end_of_tree should be between 0 and the number of nodes "
+                    "in tree_depths"));
+    const int32 num_finished = finished.shape().dim_size(0);
+    // Each finished node contributes at most two new leaves, limited by the
+    // space remaining after the current end of the tree.
+    const int32 num_new_leaves = std::min(num_finished * 2, num_nodes - eot(0));
+
+    LeafHeapType leaf_heap(
+        non_fertile_leaves.shape().dim_size(0) +
+        num_new_leaves, OrderBySecondGreater());
+    ConstructLeafHeap(
+        non_fertile_leaves, non_fertile_leaf_scores, tree_depths,
+        eot(0), num_new_leaves, pcw_total_splits.shape().dim_size(1),
+        &leaf_heap);
+
+    // Allocate leaves.
+    std::unique_ptr<HeapValuesType> values(
+        leaf_heap.Extract());
+    int32 accumulator = -1;  // This will first get incremented to 0.
+    std::unordered_map<int32, int32> accumulators_to_node;
+    FindNextAccumulator(pcw_total_splits, finished_accumulators, &accumulator);
+    const int num_candidates = static_cast<int>(values->size());
+    // `i` outlives the loop: it marks the first leaf that did not get a slot.
+    int i = 0;
+    for (; i < num_candidates; ++i) {
+      const std::pair<int32, float>& node = (*values)[i];
+      if (accumulator < 0) {
+        VLOG(1) << "No allocators left.";
+        break;
+      }
+      VLOG(1) << "setting node " << node.first << " to accumulator "
+              << accumulator;
+      accumulators_to_node[accumulator] = node.first;
+
+      FindNextAccumulator(pcw_total_splits, finished_accumulators,
+                          &accumulator);
+    }
+
+    // Construct and fill outputs.
+    SetNodeMapUpdates(accumulators_to_node, finished, context);
+    SetAccumulatorsCleared(finished_accumulators,
+                           accumulators_to_node, context);
+    SetAccumulatorsAllocated(accumulators_to_node, context);
+    SetNewNonFertileLeaves(values.get(), i, context);
+  }
+
+ private:
+  // Orders (node, score) pairs so that the pair with the larger score
+  // compares first.
+  struct OrderBySecondGreater {
+    bool operator()(const std::pair<int32, float> &left,
+                    const std::pair<int32, float> &right) const {
+      return left.second > right.second;
+    }
+  };
+
+  typedef TopN<std::pair<int32, float>, OrderBySecondGreater> LeafHeapType;
+  typedef std::vector<std::pair<int32, float>> HeapValuesType;
+
+  // Creates an update tensor for node to accumulator map.  Sets finished nodes
+  // to -1 (no accumulator assigned) and newly allocated nodes to their
+  // accumulator.  Row 0 of output 0 holds node indices, row 1 the new values.
+  void SetNodeMapUpdates(
+      const std::unordered_map<int32, int32>& accumulators_to_node,
+      const Tensor& finished, OpKernelContext* context) {
+    // Node map updates.
+    Tensor* output_node_map = nullptr;
+    TensorShape node_map_shape;
+    node_map_shape.AddDim(2);
+    node_map_shape.AddDim(accumulators_to_node.size() +
+                          finished.shape().dim_size(0));
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, node_map_shape,
+                                            &output_node_map));
+
+    auto out_node = output_node_map->tensor<int32, 2>();
+    int32 output_slot = 0;
+
+    // Set finished nodes to -1.
+    const auto finished_vec = finished.unaligned_flat<int32>();
+    for (int32 i = 0; i < finished_vec.size(); ++i) {
+      out_node(0, output_slot) = finished_vec(i);
+      out_node(1, output_slot) = -1;
+      ++output_slot;
+    }
+
+    // Set newly allocated nodes to their allocator.
+    for (const auto& node_alloc_pair : accumulators_to_node) {
+      out_node(0, output_slot) = node_alloc_pair.second;
+      out_node(1, output_slot) = node_alloc_pair.first;
+      ++output_slot;
+    }
+  }
+
+  // Creates output tensor for cleared accumulators.  Cleared accumulators are
+  // those that were finished but not re-allocated.
+  void SetAccumulatorsCleared(
+      const std::set<int32>& finished_accumulators,
+      const std::unordered_map<int32, int32>& accumulators_to_node,
+      OpKernelContext* context) {
+    std::set<int32> cleared;
+    for (const int32 finished_accumulator : finished_accumulators) {
+      if (accumulators_to_node.find(finished_accumulator) ==
+          accumulators_to_node.end()) {
+        cleared.insert(finished_accumulator);
+      }
+    }
+
+    Tensor* output_cleared = nullptr;
+    TensorShape cleared_shape;
+    cleared_shape.AddDim(cleared.size());
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, cleared_shape,
+                                            &output_cleared));
+
+    auto out = output_cleared->unaligned_flat<int32>();
+
+    int32 i = 0;
+    for (const int32 accumulator : cleared) {
+      out(i) = accumulator;
+      ++i;
+    }
+  }
+
+  // Creates output tensor for accumulators that were allocated to now-fertile
+  // nodes.
+  void SetAccumulatorsAllocated(
+      const std::unordered_map<int32, int32>& accumulators_to_node,
+      OpKernelContext* context) {
+    // One entry per accumulator that was handed out this round.
+    Tensor* output_allocated = nullptr;
+    TensorShape allocated_shape;
+    allocated_shape.AddDim(accumulators_to_node.size());
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, allocated_shape,
+                                            &output_allocated));
+
+    auto out = output_allocated->unaligned_flat<int32>();
+    int32 output_slot = 0;
+
+    // The map is keyed by accumulator, so emit the keys.
+    for (const auto& node_alloc_pair : accumulators_to_node) {
+      out(output_slot) = node_alloc_pair.first;
+      ++output_slot;
+    }
+  }
+
+  // Creates output tensors for non-fertile leaves and non-fertile leaf scores.
+  // Start indicates the index in values where the leaves that weren't
+  // allocated this round begin, and should thus be placed in the new
+  // nonfertile_leaves tensors.
+  void SetNewNonFertileLeaves(HeapValuesType* values, int start,
+                              OpKernelContext* context) {
+    // Number of leaves that were left without an accumulator.
+    int32 num_values = values->size() - start;
+
+    // Unfortunately, a zero-sized Variable results in an uninitialized
+    // error, probably because they check for zero size instead of
+    // a real initialization condition.
+    bool fill_with_garbage = false;
+    if (num_values == 0) {
+      num_values = 1;
+      fill_with_garbage = true;
+    }
+    Tensor* output_nonfertile_leaves = nullptr;
+    TensorShape nonfertile_leaves_shape;
+    nonfertile_leaves_shape.AddDim(num_values);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(3, nonfertile_leaves_shape,
+                                            &output_nonfertile_leaves));
+
+    auto out_nonfertile_leaves =
+        output_nonfertile_leaves->unaligned_flat<int32>();
+
+    Tensor* output_nonfertile_leaves_scores = nullptr;
+    TensorShape nonfertile_leaves_scores_shape;
+    nonfertile_leaves_scores_shape.AddDim(num_values);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(4, nonfertile_leaves_scores_shape,
+                                            &output_nonfertile_leaves_scores));
+
+    auto out_nonfertile_leaves_scores =
+        output_nonfertile_leaves_scores->unaligned_flat<float>();
+
+    if (fill_with_garbage) {
+      // Placeholder entry: a negative leaf index with a zero score.
+      out_nonfertile_leaves(0) = -1;
+      out_nonfertile_leaves_scores(0) = 0.0;
+      return;
+    }
+
+    for (int32 i = start; i < values->size(); ++i) {
+      const std::pair<int32, float>& node = (*values)[i];
+      out_nonfertile_leaves(i - start) = node.first;
+      out_nonfertile_leaves_scores(i - start) = node.second;
+    }
+  }
+
+  // Pushes every candidate leaf onto leaf_heap: the existing non-fertile
+  // leaves with their recorded scores, then the num_new_leaves nodes starting
+  // at end_of_tree with the score of an all-zero count vector.  Leaves whose
+  // depth is not below max_depth_ are left out.
+  void ConstructLeafHeap(
+      const Tensor& non_fertile_leaves, const Tensor& non_fertile_leaf_scores,
+      const Tensor& tree_depths, int32 end_of_tree, int32 num_new_leaves,
+      int32 num_classes, LeafHeapType* leaf_heap) {
+    const auto leaf_vec = non_fertile_leaves.unaligned_flat<int32>();
+    const auto leaf_score_vec = non_fertile_leaf_scores.unaligned_flat<float>();
+    const auto depths = tree_depths.unaligned_flat<int32>();
+    const int32 num_nodes = static_cast<int32>(depths.size());
+
+    for (int i = 0; i < leaf_vec.size(); i++) {
+      const int32 leaf = leaf_vec(i);
+      // Filter out leaves < 0, non_fertile_nodes can contain garbage at
+      // startup.  Indices past the end of the tree are skipped the same way
+      // so that depths is never read out of range.
+      if (leaf >= 0 && leaf < num_nodes && depths(leaf) < max_depth_) {
+        leaf_heap->push(std::make_pair(leaf, leaf_score_vec(i)));
+      }
+    }
+
+    // Add new leaves.
+    Eigen::Tensor<float, 1, Eigen::RowMajor> zeros(num_classes);
+    zeros.setZero();
+    const float zero_score = WeightedGiniImpurity(zeros);
+    for (int leaf = end_of_tree; leaf < end_of_tree + num_new_leaves; leaf++) {
+      if (depths(leaf) < max_depth_) {
+        leaf_heap->push(std::make_pair(leaf, zero_score));
+      }
+    }
+  }
+
+  // Finds the next available or newly-finished accumulator.  Advances
+  // *current to the next slot whose first total is negative or which is in
+  // finished_accumulators; sets *current to -1 when no such slot remains.
+  void FindNextAccumulator(const Tensor& totals_tensor,
+                           const std::set<int32>& finished_accumulators,
+                           int* current) {
+    ++(*current);
+    const auto totals = totals_tensor.tensor<float, 2>();
+    for (; *current < totals_tensor.shape().dim_size(0); ++(*current)) {
+      if (totals(*current, 0) < 0 ||
+          finished_accumulators.find(*current) != finished_accumulators.end()) {
+        return;
+      }
+    }
+    *current = -1;
+  }
+
+  // Value of the "max_depth" attribute; leaves at or beyond this depth are
+  // never pushed onto the leaf heap.
+  int32 max_depth_;
+};
+
+// Registers the UpdateFertileSlots kernel class as the CPU implementation of
+// the "UpdateFertileSlots" op.
+REGISTER_KERNEL_BUILDER(Name("UpdateFertileSlots").Device(DEVICE_CPU),
+                        UpdateFertileSlots);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
new file mode 100644
index 0000000..ead6198
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
@@ -0,0 +1,71 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.best_splits_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class BestSplitsTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.finished = [3, 5]
+ self.node_map = [-1, -1, -1, 0, -1, 3, -1, -1, -1]
+ self.candidate_counts = [[[50., 60., 40., 3.], [70., 30., 70., 30.]],
+ [[0., 0., 0., 0.], [0., 0., 0., 0.]],
+ [[0., 0., 0., 0.], [0., 0., 0., 0.]],
+ [[10., 10., 10., 10.], [10., 5., 5., 10.]]]
+ self.total_counts = [[100., 100., 100., 100.],
+ [0., 0., 0., 0.],
+ [0., 0., 0., 0.],
+ [100., 100., 100., 100.]]
+ self.ops = training_ops.Load()
+
+ def testSimple(self):
+ with self.test_session():
+ split_indices = self.ops.best_splits(
+ self.finished, self.node_map, self.candidate_counts,
+ self.total_counts)
+
+ self.assertAllEqual([0, 1], split_indices.eval())
+
+ def testNoFinished(self):
+ with self.test_session():
+ split_indices = self.ops.best_splits(
+ [], self.node_map, self.candidate_counts, self.total_counts)
+
+ self.assertAllEqual([], split_indices.eval())
+
+ def testBadInput(self):
+ del self.total_counts[1]
+
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Number of accumulators should be the same in pcw_candidate_splits '
+ 'and pcw_total_splits.'):
+ self.ops.best_splits(
+ self.finished, self.node_map, self.candidate_counts,
+ self.total_counts).eval()
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
new file mode 100644
index 0000000..e93bb17
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
@@ -0,0 +1,94 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.count_extremely_random_stats."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class CountExtremelyRandomStatsTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.input_data = [[-1., 0.], [-1., 2.], # node 1
+ [1., 0.], [1., -2.]] # node 2
+ self.input_labels = [0, 1, 2, 3]
+ self.tree = [[1, 0], [-1, 0], [-1, 0]]
+ self.tree_thresholds = [0., 0., 0.]
+ self.node_map = [-1, 0, -1]
+ self.split_features = [[1], [-1]]
+ self.split_thresholds = [[1.], [0.]]
+ self.ops = training_ops.Load()
+
+ def testSimple(self):
+ with self.test_session():
+ (pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
+ pcw_totals_delta, leaves) = (
+ self.ops.count_extremely_random_stats(
+ self.input_data, self.input_labels, self.tree,
+ self.tree_thresholds, self.node_map,
+ self.split_features, self.split_thresholds, num_classes=4))
+
+ self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.],
+ [0., 0., 1., 1.]],
+ pcw_node.eval())
+ self.assertAllEqual([[0, 0, 0]], pcw_splits_indices.eval())
+ self.assertAllEqual([1.], pcw_splits_delta.eval())
+ self.assertAllEqual([[0, 1], [0, 0]], pcw_totals_indices.eval())
+ self.assertAllEqual([1., 1.], pcw_totals_delta.eval())
+ self.assertAllEqual([1, 1, 2, 2], leaves.eval())
+
+ def testNoAccumulators(self):
+ with self.test_session():
+ (pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
+ pcw_totals_delta, leaves) = (
+ self.ops.count_extremely_random_stats(
+ self.input_data, self.input_labels, self.tree,
+ self.tree_thresholds, [-1] * 3,
+ self.split_features, self.split_thresholds, num_classes=4))
+
+ self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.],
+ [0., 0., 1., 1.]],
+ pcw_node.eval())
+ self.assertEquals((0, 3), pcw_splits_indices.eval().shape)
+ self.assertAllEqual([], pcw_splits_delta.eval())
+ self.assertEquals((0, 2), pcw_totals_indices.eval().shape)
+ self.assertAllEqual([], pcw_totals_delta.eval())
+ self.assertAllEqual([1, 1, 2, 2], leaves.eval())
+
+ def testBadInput(self):
+ del self.node_map[-1]
+
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Number of nodes should be the same in '
+ 'tree, tree_thresholds, and node_to_accumulator'):
+ pcw_node, _, _, _, _, _ = (
+ self.ops.count_extremely_random_stats(
+ self.input_data, self.input_labels, self.tree,
+ self.tree_thresholds, self.node_map,
+ self.split_features, self.split_thresholds, num_classes=4))
+
+ self.assertAllEqual([], pcw_node.eval())
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
new file mode 100644
index 0000000..84583ae
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
@@ -0,0 +1,63 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.finished_nodes_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class FinishedNodesTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.leaves = [1, 3, 4]
+ self.node_map = [-1, -1, -1, 0, 1, -1]
+ self.pcw_total_splits = [[3, 3], [4, 7], [0, 0], [0, 0], [0, 0]]
+ self.ops = training_ops.Load()
+
+ def testSimple(self):
+ with self.test_session():
+ finished = self.ops.finished_nodes(self.leaves, self.node_map,
+ self.pcw_total_splits,
+ num_split_after_samples=10)
+
+ self.assertAllEqual([4], finished.eval())
+
+ def testNoAccumulators(self):
+ with self.test_session():
+ finished = self.ops.finished_nodes(self.leaves, [-1] * 6,
+ self.pcw_total_splits,
+ num_split_after_samples=10)
+
+ self.assertAllEqual([], finished.eval())
+
+ def testBadInput(self):
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'leaf_tensor should be one-dimensional'):
+ finished = self.ops.finished_nodes([self.leaves], self.node_map,
+ self.pcw_total_splits,
+ num_split_after_samples=10)
+
+ self.assertAllEqual([], finished.eval())
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/grow_tree_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/grow_tree_op_test.py
new file mode 100644
index 0000000..8177e2d
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/grow_tree_op_test.py
@@ -0,0 +1,105 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.grow_tree_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class GrowTreeTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.tree = tf.Variable([[1, 0], [-1, 0], [-1, 0],
+ [-2, 0], [-2, 0], [-2, 0], [-2, 0]])
+ self.tree_thresholds = tf.Variable([0., 0., 0., 0., 0., 0., 0.])
+ self.eot = tf.Variable([3])
+ self.depths = tf.Variable([1, 2, 2, -1, -1, -1, -1])
+ self.node_map = [-1, 0, 1, -1, -1, -1, -1]
+ self.finished = [1, 2]
+ self.best_splits = [2, 3]
+ self.split_features = [[1, 2, 3, 4], [5, 6, 7, 8]]
+ self.split_thresholds = [[10., 20., 30., 40.], [50., 60., 70., 80.]]
+ self.ops = training_ops.Load()
+
+ def testSimple(self):
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ update_list, tree_updates, threshold_updates, depth_updates, new_eot = (
+ self.ops.grow_tree(self.eot, self.depths, self.node_map,
+ self.finished, self.best_splits,
+ self.split_features, self.split_thresholds))
+
+ self.assertAllEqual([1, 3, 4, 2, 5, 6], update_list.eval())
+ self.assertAllEqual(
+ [[3, 3], [-1, -1], [-1, -1], [5, 8], [-1, -1], [-1, -1]],
+ tree_updates.eval())
+ self.assertAllEqual([30.0, 0.0, 0.0, 80.0, 0.0, 0.0],
+ threshold_updates.eval())
+ self.assertAllEqual([2, 3, 3, 2, 3, 3], depth_updates.eval())
+ self.assertAllEqual([7], new_eot.eval())
+
+ def testNoRoomToGrow(self):
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ # Even though there's one free node, there needs to be 2 to grow.
+ tf.assign(self.eot, [6]).eval()
+
+ update_list, tree_updates, threshold_updates, depth_updates, new_eot = (
+ self.ops.grow_tree(self.eot, self.depths, self.node_map,
+ self.finished, self.best_splits,
+ self.split_features, self.split_thresholds))
+
+ self.assertAllEqual([], update_list.eval())
+ self.assertEquals((0, 2), tree_updates.eval().shape)
+ self.assertAllEqual([], threshold_updates.eval())
+ self.assertAllEqual([], depth_updates.eval())
+ self.assertAllEqual([6], new_eot.eval())
+
+ def testNoFinished(self):
+ with self.test_session():
+ tf.initialize_all_variables().run()
+
+ update_list, tree_updates, threshold_updates, depth_updates, new_eot = (
+ self.ops.grow_tree(self.eot, self.depths, self.node_map, [], [],
+ self.split_features, self.split_thresholds))
+
+ self.assertAllEqual([], update_list.eval())
+ self.assertAllEqual((0, 2), tree_updates.eval().shape)
+ self.assertAllEqual([], threshold_updates.eval())
+ self.assertAllEqual([], depth_updates.eval())
+ self.assertAllEqual([3], new_eot.eval())
+
+ def testBadInput(self):
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ with self.assertRaisesOpError(
+ 'Number of finished nodes should be the same in finished and '
+ 'best_splits.'):
+ update_list, _, _, _, _ = (
+ self.ops.grow_tree(self.eot, self.depths, self.node_map,
+ [], self.best_splits,
+ self.split_features, self.split_thresholds))
+ self.assertAllEqual([], update_list.eval())
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
new file mode 100644
index 0000000..d050747
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
@@ -0,0 +1,79 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.sample_inputs_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class SampleInputsTest(test_util.TensorFlowTestCase):
+ """Tests the sample_inputs op from the training ops library."""
+
+ def setUp(self):
+ # Four two-feature inputs; the first two land in leaf 1 and the last two
+ # in leaf 2 (see self.leaves).
+ self.input_data = [[-1., 10.], [-10., 2.], # node 1
+ [20., 50.], [1., -2.]] # node 2
+ self.node_map = [-1, 0, 1]
+ self.leaves = [1, 1, 2, 2]
+ # Three rows of three candidate splits each; only row 1 has entries that
+ # are not -1 (two of its three slots are filled).
+ self.split_features = [[-1, -1, -1], [1, 0, -1], [-1, -1, -1]]
+ self.split_thresholds = [[0., 0., 0.], [5., -2., 0.], [0., 0., 0.]]
+ self.ops = training_ops.Load()
+
+ def testSimple(self):
+ """Checks the sampled updates produced with a fixed random seed (3)."""
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ indices, feature_updates, threshold_updates = (
+ self.ops.sample_inputs(
+ self.input_data, self.node_map, self.leaves, self.split_features,
+ self.split_thresholds, split_initializations_per_input=1,
+ split_sampling_random_seed=3))
+ self.assertAllEqual([1, 0], indices.eval())
+ self.assertAllEqual([[1, 0, 1], [0, 0, -1]],
+ feature_updates.eval())
+ self.assertAllEqual([[5., -2., 50.], [-1., -10., 0.]],
+ threshold_updates.eval())
+
+ def testNoAccumulators(self):
+ """A node_map that is all -1 yields empty outputs."""
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ indices, feature_updates, threshold_updates = (
+ self.ops.sample_inputs(
+ self.input_data, [-1] * 3, self.leaves, self.split_features,
+ self.split_thresholds, split_initializations_per_input=1,
+ split_sampling_random_seed=3))
+ self.assertAllEqual([], indices.eval())
+ self.assertAllEqual((0, 3), feature_updates.eval().shape)
+ self.assertAllEqual((0, 3), threshold_updates.eval().shape)
+
+ def testBadInput(self):
+ """Mismatched split_features / split_thresholds shapes raise an error."""
+ # Drop one row so split_features has 2 rows while split_thresholds has 3.
+ del self.split_features[1]
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ with self.assertRaisesOpError(
+ 'split_features and split_thresholds should be the same shape.'):
+ indices, _, _ = self.ops.sample_inputs(
+ self.input_data, self.node_map, self.leaves, self.split_features,
+ self.split_thresholds, split_initializations_per_input=1,
+ split_sampling_random_seed=3)
+ # The error is only raised once the op is actually evaluated.
+ self.assertAllEqual([], indices.eval())
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
new file mode 100644
index 0000000..467ffed
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
@@ -0,0 +1,82 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.scatter_add_ndim_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class ScatterAddNdimTest(test_util.TensorFlowTestCase):
+ """Tests the scatter_add_ndim op from the training ops library."""
+
+ def setUp(self):
+ self.ops = training_ops.Load()
+
+ def test1dim(self):
+ """Adds 100 at index 1 and 200 at index 10 of a 12-element vector."""
+ input_data = tf.Variable([1., 2., 3., 4., 5., 6.,
+ 7., 8., 9., 10., 11., 12.])
+ indices = [[1], [10]]
+ updates = [100., 200.]
+
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ # The op updates the variable in place; it is run for its side effect.
+ self.ops.scatter_add_ndim(input_data, indices, updates).run()
+ self.assertAllEqual([1., 102., 3., 4., 5., 6.,
+ 7., 8., 9., 10., 211., 12.], input_data.eval())
+
+ def test3dim(self):
+ """Adds to elements [0, 0, 1] and [1, 1, 2] of a 2x2x3 variable."""
+ input_data = tf.Variable([[[1., 2., 3.], [4., 5., 6.]],
+ [[7., 8., 9.], [10., 11., 12.]]])
+ indices = [[0, 0, 1], [1, 1, 2]]
+ updates = [100., 200.]
+
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ self.ops.scatter_add_ndim(input_data, indices, updates).run()
+ self.assertAllEqual([[[1., 102., 3.], [4., 5., 6.]],
+ [[7., 8., 9.], [10., 11., 212.]]], input_data.eval())
+
+ def testNoUpdates(self):
+ """Empty indices and updates leave the variable unchanged."""
+ init_val = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]]
+ input_data = tf.Variable(init_val)
+ indices = []
+ updates = []
+
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ self.ops.scatter_add_ndim(input_data, indices, updates).run()
+ self.assertAllEqual(init_val, input_data.eval())
+
+ def testBadInput(self):
+ """Two indices with only one update value raise an op error."""
+ init_val = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]]
+ input_data = tf.Variable(init_val)
+ indices = [[0, 0, 1], [1, 1, 2]]
+ updates = [100.]
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ with self.assertRaisesOpError(
+ 'Number of updates should be same as number of indices.'):
+ self.ops.scatter_add_ndim(input_data, indices, updates).run()
+ self.assertAllEqual(init_val, input_data.eval())
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
new file mode 100644
index 0000000..743cd83
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
@@ -0,0 +1,103 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.tree_predictions_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import inference_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class TreePredictionsTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.ops = inference_ops.Load()
+
+ def testSimple(self):
+ input_data = [[-1., 0.], [-1., 2.], # node 1
+ [1., 0.], [1., -2.]] # node 2
+
+ tree = [[1, 0], [-1, 0], [-1, 0]]
+ tree_thresholds = [0., 0., 0.]
+ node_pcw = [[0.3, 0.4, 0.3], [0.1, 0.1, 0.8], [0.5, 0.25, 0.25]]
+
+ with self.test_session():
+ predictions = self.ops.tree_predictions(
+ input_data, tree, tree_thresholds, node_pcw,
+ valid_leaf_threshold=1)
+
+ self.assertAllClose([[0.1, 0.1, 0.8], [0.1, 0.1, 0.8],
+ [0.5, 0.25, 0.25], [0.5, 0.25, 0.25]],
+ predictions.eval())
+
+ def testBackoffToParent(self):
+ input_data = [[-1., 0.], [-1., 2.], # node 1
+ [1., 0.], [1., -2.]] # node 2
+
+ tree = [[1, 0], [-1, 0], [-1, 0]]
+ tree_thresholds = [0., 0., 0.]
+ node_pcw = [[3.0, 9.0, 3.0], [1.0, 1.0, 3.0], [5.0, 20.0, 0.0]]
+
+ with self.test_session():
+ predictions = self.ops.tree_predictions(
+ input_data, tree, tree_thresholds, node_pcw,
+ valid_leaf_threshold=10)
+
+ # Node 2 has enough data, but Node 1 needs to combine with the parent
+ # counts.
+ self.assertAllClose([[0.2, 0.4, 0.4], [0.2, 0.4, 0.4],
+ [0.2, 0.8, 0.0], [0.2, 0.8, 0.0]],
+ predictions.eval())
+
+ def testNoInput(self):
+ input_data = []
+
+ tree = [[1, 0], [-1, 0], [-1, 0]]
+ tree_thresholds = [0., 0., 0.]
+ node_pcw = [[0.3, 0.4, 0.3], [0.1, 0.1, 0.8], [0.5, 0.25, 0.25]]
+
+ with self.test_session():
+ predictions = self.ops.tree_predictions(
+ input_data, tree, tree_thresholds, node_pcw,
+ valid_leaf_threshold=10)
+
+ self.assertEquals((0, 3), predictions.eval().shape)
+
+ def testBadInput(self):
+ input_data = [[-1., 0.], [-1., 2.], # node 1
+ [1., 0.], [1., -2.]] # node 2
+
+ tree = [[1, 0], [-1, 0], [-1, 0]]
+ tree_thresholds = [0., 0.] # not enough nodes.
+ node_pcw = [[0.3, 0.4, 0.3], [0.1, 0.1, 0.8], [0.5, 0.25, 0.25]]
+
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Number of nodes should be the same in tree, tree_thresholds '
+ 'and node_pcw.'):
+ predictions = self.ops.tree_predictions(
+ input_data, tree, tree_thresholds, node_pcw,
+ valid_leaf_threshold=10)
+
+ self.assertEquals((0, 3), predictions.eval().shape)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
new file mode 100644
index 0000000..36c232f
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
@@ -0,0 +1,102 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.update_fertile_slots_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
+ """Tests the update_fertile_slots op from the training ops library."""
+
+ def setUp(self):
+ # tree is:
+ # 0
+ # 1 2
+ # 3 4 5 6
+ # Node 2 is the only finished node and the only one with a node_map entry
+ # other than -1 (value 0).
+ self.finished = [2]
+ self.depths = [1, 2, 2, 3, 3, 3, 3]
+ self.non_fertile_leaves = [3, 4]
+ self.non_fertile_leaf_scores = [10., 15.]
+ self.end_of_tree = [5]
+ self.node_map = [-1, -1, 0, -1, -1, -1, -1]
+ self.candidate_counts = [[[10., 20.], [30., 10.]]]
+ self.total_counts = [[40., 40.]]
+ self.ops = training_ops.Load()
+
+ def testSimple(self):
+ """With max_depth=4, node 2's slot (0) is handed over to node 4."""
+ with self.test_session():
+ (node_map_updates, accumulators_cleared, accumulators_allocated,
+ new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+ self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
+ self.end_of_tree, self.depths, self.candidate_counts,
+ self.total_counts, self.node_map, max_depth=4)
+
+ # node_map_updates has two rows with one column per update.
+ # NOTE(review): presumably row 0 is node ids and row 1 the new
+ # node_map values -- confirm against the op kernel.
+ self.assertAllEqual([[2, 4], [-1, 0]], node_map_updates.eval())
+ self.assertAllEqual([], accumulators_cleared.eval())
+ self.assertAllEqual([0], accumulators_allocated.eval())
+ self.assertAllEqual([3, 5, 6], new_nfl.eval())
+ self.assertAllEqual([10., 1., 1.], new_nfl_scores.eval())
+
+ def testReachedMaxDepth(self):
+ """With max_depth=3, slot 0 is cleared and nothing is allocated."""
+ with self.test_session():
+ (node_map_updates, accumulators_cleared, accumulators_allocated,
+ new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+ self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
+ self.end_of_tree, self.depths, self.candidate_counts,
+ self.total_counts, self.node_map, max_depth=3)
+
+ self.assertAllEqual([[2], [-1]], node_map_updates.eval())
+ self.assertAllEqual([0], accumulators_cleared.eval())
+ self.assertAllEqual([], accumulators_allocated.eval())
+ self.assertAllEqual([-1], new_nfl.eval())
+ self.assertAllEqual([0.0], new_nfl_scores.eval())
+
+ def testNoFinished(self):
+ """An empty finished list produces no node_map or accumulator updates."""
+ with self.test_session():
+ (node_map_updates, accumulators_cleared, accumulators_allocated,
+ new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+ [], self.non_fertile_leaves, self.non_fertile_leaf_scores,
+ self.end_of_tree, self.depths, self.candidate_counts,
+ self.total_counts, self.node_map, max_depth=4)
+
+ self.assertAllEqual((2, 0), node_map_updates.eval().shape)
+ self.assertAllEqual([], accumulators_cleared.eval())
+ self.assertAllEqual([], accumulators_allocated.eval())
+ # The non-fertile leaves come back ordered by descending score.
+ self.assertAllEqual([4, 3], new_nfl.eval())
+ self.assertAllEqual([15., 10.], new_nfl_scores.eval())
+
+ def testBadInput(self):
+ """Mismatched leaf and score list lengths raise an op error."""
+ # Drop one score so there are 2 non-fertile leaves but only 1 score.
+ del self.non_fertile_leaf_scores[-1]
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Number of non fertile leaves should be the same in '
+ 'non_fertile_leaves and non_fertile_leaf_scores.'):
+ (node_map_updates, _, _, _, _) = self.ops.update_fertile_slots(
+ self.finished, self.non_fertile_leaves,
+ self.non_fertile_leaf_scores, self.end_of_tree, self.depths,
+ self.candidate_counts, self.total_counts,
+ self.node_map, max_depth=4)
+ # The error is only raised once the op is actually evaluated.
+ self.assertAllEqual((2, 0), node_map_updates.eval().shape)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
new file mode 100644
index 0000000..7cad6a8
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
@@ -0,0 +1,63 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for BrainTree v2 tree evaluation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+
+import tensorflow as tf
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+# File name of the compiled inference ops shared library; Load() joins it
+# with the data files path.
+INFERENCE_OPS_FILE = '_inference_ops.so'
+
+# Cached handle to the loaded library (None until Load() succeeds) and the
+# lock that guards its initialization.
+_inference_ops = None
+_ops_lock = threading.Lock()
+
+
+# Declare that TreePredictions has no gradient.
+ops.NoGradient('TreePredictions')
+
+
+# NOTE(review): the shape functions in training_ops.py are named
+# _<OpName>Shape; this one is not -- consider aligning the name.
+@ops.RegisterShape('TreePredictions')
+def TreePredictions(op):
+ """Shape function for TreePredictions Op.
+
+ Args:
+ op: The TreePredictions op whose input shapes are inspected.
+
+ Returns:
+ A one-element list holding the output shape [num_points, num_classes].
+ """
+ # num_points is the first dimension of input 0; num_classes is the second
+ # dimension of input 3.
+ num_points = op.inputs[0].get_shape()[0].value
+ num_classes = op.inputs[3].get_shape()[1].value
+ # The output of TreePredictions is
+ # [node_pcw(evaluate_tree(x), c) for c in classes for x in input_data].
+ # i.e. one row per input point and one column per class.
+ return [tensor_shape.TensorShape([num_points, num_classes])]
+
+
+# Workaround for the fact that importing tensorflow imports contrib
+# (even if a user isn't using this or any other contrib op), but
+# there's not yet any guarantee that the shared object exists.
+# In which case, "import tensorflow" will always crash, even for users that
+# never use contrib.
+def Load():
+ """Load the inference ops library and return the loaded module.
+
+ The library is loaded on the first call and cached in the module-level
+ _inference_ops; later calls return the cached module.
+
+ Returns:
+ The module returned by tf.load_op_library for INFERENCE_OPS_FILE.
+
+ Raises:
+ AssertionError: if the library handle is still falsy after loading.
+ """
+ # Hold the lock so the library is only loaded once across threads.
+ with _ops_lock:
+ global _inference_ops
+ if not _inference_ops:
+ # The shared object is looked up in the data files directory.
+ data_files_path = tf.resource_loader.get_data_files_path()
+ tf.logging.info('data path: %s', data_files_path)
+ _inference_ops = tf.load_op_library(os.path.join(
+ data_files_path, INFERENCE_OPS_FILE))
+
+ # NOTE(review): the message omits the leading underscore of the actual
+ # file name ('_inference_ops.so').
+ assert _inference_ops, 'Could not load inference_ops.so'
+ return _inference_ops
diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
new file mode 100644
index 0000000..8ca2491
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
@@ -0,0 +1,110 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for BrainTree v2 training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+
+import tensorflow as tf
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+# File name of the compiled training ops shared library; Load() joins it
+# with the data files path.
+TRAINING_OPS_FILE = '_training_ops.so'
+
+# Cached handle to the loaded library (None until Load() succeeds) and the
+# lock that guards its initialization.
+_training_ops = None
+_ops_lock = threading.Lock()
+
+# Declare that none of the training ops has a gradient.
+ops.NoGradient('CountExtremelyRandomStats')
+ops.NoGradient('SampleInputs')
+ops.NoGradient('BestSplits')
+ops.NoGradient('GrowTree')
+ops.NoGradient('FinishedNodes')
+ops.NoGradient('ScatterAddNdim')
+ops.NoGradient('UpdateFertileSlots')
+
+
+@ops.RegisterShape('CountExtremelyRandomStats')
+def _CountExtremelyRandomStatsShape(op):
+ """Shape function for CountExtremelyRandomStats Op.
+
+ Args:
+ op: The CountExtremelyRandomStats op; inputs 0 and 2 and the
+ 'num_classes' attr are read.
+
+ Returns:
+ A list of six output shapes. Only the first ([num_nodes, num_classes])
+ and the last ([num_points]) are fully determined by the inputs.
+ """
+ num_points = op.inputs[0].get_shape()[0].value
+ num_nodes = op.inputs[2].get_shape()[0].value
+ num_classes = op.get_attr('num_classes')
+ # The last output has one entry per input point.
+ # NOTE(review): presumably [leaf_node_index(x) for x in input_data]; the
+ # earlier comment here named a different op ("TraverseTree") -- confirm.
+ return [tensor_shape.TensorShape([num_nodes, num_classes]), # node pcw
+ tensor_shape.TensorShape([None, 3]),
+ tensor_shape.TensorShape([None]),
+ tensor_shape.TensorShape([None, 2]),
+ tensor_shape.TensorShape([None]),
+ tensor_shape.TensorShape([num_points])]
+
+
+@ops.RegisterShape('SampleInputs')
+def _SampleInputsShape(op):
+ """Shape function for SampleInputs Op.
+
+ Args:
+ op: The SampleInputs op; the second dimension of input 3 is read.
+
+ Returns:
+ Three output shapes: [None], [None, num_splits], [None, num_splits].
+ """
+ # num_splits is the second dimension of input 3.
+ num_splits = op.inputs[3].get_shape()[1].value
+ return [[None], [None, num_splits], [None, num_splits]]
+
+
+@ops.RegisterShape('BestSplits')
+def _BestSplitsShape(op):
+ """Shape function for BestSplits Op.
+
+ Args:
+ op: The BestSplits op; the first dimension of input 0 is read.
+
+ Returns:
+ A one-element list holding the output shape [num_finished].
+ """
+ num_finished = op.inputs[0].get_shape()[0].value
+ return [tensor_shape.TensorShape([num_finished])]
+
+
+@ops.RegisterShape('GrowTree')
+def _GrowTreeShape(unused_op):
+ """Shape function for GrowTree Op.
+
+ Args:
+ unused_op: The GrowTree op (not inspected).
+
+ Returns:
+ Five output shapes; only the last one ([1]) is fully static.
+ """
+ return [[None], [None, 2], [None], [None], [1]]
+
+
+@ops.RegisterShape('FinishedNodes')
+def _FinishedNodesShape(unused_op):
+ """Shape function for FinishedNodes Op.
+
+ Args:
+ unused_op: The FinishedNodes op (not inspected).
+
+ Returns:
+ A one-element list holding a 1-D shape of unknown length.
+ """
+ return [[None]]
+
+
+@ops.RegisterShape('ScatterAddNdim')
+def _ScatterAddNdimShape(unused_op):
+ """Shape function for ScatterAddNdim Op.
+
+ Args:
+ unused_op: The ScatterAddNdim op (not inspected).
+
+ Returns:
+ An empty list: no output shapes are declared for this op.
+ """
+ return []
+
+
+@ops.RegisterShape('UpdateFertileSlots')
+def _UpdateFertileSlotsShape(unused_op):
+ """Shape function for UpdateFertileSlots Op."""
+ return [[None, 2], [None], [None], [None], [None]]
+
+
+# Workaround for the fact that importing tensorflow imports contrib
+# (even if a user isn't using this or any other contrib op), but
+# there's not yet any guarantee that the shared object exists.
+# In which case, "import tensorflow" will always crash, even for users that
+# never use contrib.
+def Load():
+ """Load training ops library and return the loaded module.
+
+ The library is loaded on the first call and cached in the module-level
+ _training_ops; later calls return the cached module.
+
+ Returns:
+ The module returned by tf.load_op_library for TRAINING_OPS_FILE.
+
+ Raises:
+ AssertionError: if the library handle is still falsy after loading.
+ """
+ # Hold the lock so the library is only loaded once across threads.
+ with _ops_lock:
+ global _training_ops
+ if not _training_ops:
+ # The shared object is looked up in the data files directory.
+ data_files_path = tf.resource_loader.get_data_files_path()
+ tf.logging.info('data path: %s', data_files_path)
+ _training_ops = tf.load_op_library(os.path.join(
+ data_files_path, TRAINING_OPS_FILE))
+
+ assert _training_ops, 'Could not load _training_ops.so'
+ return _training_ops
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
new file mode 100644
index 0000000..27b2b6d
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -0,0 +1,537 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extremely random forest graph builder. go/brain-tree."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python.ops import inference_ops
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+
+# Default parameter values. ForestHParams.fill() falls back on these for any
+# corresponding parameter that was not passed to the ForestHParams
+# constructor.
+flags.DEFINE_integer('num_trees', 100, 'Number of trees in forest')
+flags.DEFINE_integer('max_nodes', 10000, 'Maximum number of tree nodes.')
+flags.DEFINE_float(
+    'samples_to_decide', 25.0,
+    'Only decide on a split, or only fully use a leaf, after this many '
+    'training samples have been seen.')
+
+# If tree[i][0] equals this value, then i is a leaf node.
+LEAF_NODE = -1
+
+
+# A convenience class for holding random forest hyperparameters.
+#
+# To just get some good default parameters, use:
+# hparams = ForestHParams(num_classes=2, num_features=40).fill()
+#
+# Note that num_classes can not be inferred and so must always be specified.
+# Also, either num_splits_to_consider or num_features should be set.
+#
+# To override specific values, pass them to the constructor:
+# hparams = ForestHParams(num_classes=5, num_trees=10, num_features=5).fill()
+#
+# TODO(thomaswc): Inherit from tf.HParams when that is publicly available.
+class ForestHParams(object):
+  """A base class for holding hyperparameters and calculating good defaults.
+
+  Every keyword argument passed to the constructor becomes an attribute of
+  the instance. fill() then derives a value for each hyperparameter that was
+  not supplied.
+  """
+
+  def __init__(self, **kwargs):
+    # dict.items() works on both Python 2 and Python 3, whereas iteritems()
+    # only exists on Python 2.
+    for name, value in kwargs.items():
+      setattr(self, name, value)
+
+  def values(self):
+    """Returns the instance's attribute dict (the live __dict__, not a copy)."""
+    return self.__dict__
+
+  def fill(self):
+    """Intelligently sets any non-specific parameters.
+
+    num_classes must already be set, and so must num_features unless
+    num_splits_to_consider was given.
+
+    Returns:
+      self, so that construction and fill() can be chained.
+
+    Raises:
+      AttributeError: if num_classes is missing, or if num_features is needed
+        to derive num_splits_to_consider but is missing.
+    """
+    # Fail fast if num_classes isn't set.
+    _ = getattr(self, 'num_classes')
+
+    self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees)
+    self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes)
+
+    # Allow each tree to be unbalanced by up to a factor of 2.
+    self.max_depth = getattr(self, 'max_depth',
+                             int(2 * math.ceil(math.log(self.max_nodes, 2))))
+
+    # The Random Forest literature recommends sqrt(# features) for
+    # classification problems, and p/3 for regression problems.
+    # TODO(thomaswc): Consider capping this for large number of features.
+    if not getattr(self, 'num_splits_to_consider', None):
+      self.num_splits_to_consider = max(10, int(
+          math.ceil(math.sqrt(self.num_features))))
+
+    # max_fertile_nodes doesn't affect performance, only training speed.
+    # We therefore set it primarily based upon space considerations.
+    # Each fertile node takes up num_splits_to_consider times as much
+    # space as a non-fertile node.  We want the fertile nodes to in
+    # total only take up as much space as the non-fertile nodes, so
+    num_fertile = int(math.ceil(self.max_nodes / self.num_splits_to_consider))
+    # But always use at least 1000 accumulate slots.
+    num_fertile = max(num_fertile, 1000)
+    self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile)
+    # But it also never needs to be larger than the number of leaves,
+    # which is max_nodes / 2.  (The cap is the leaf count; the value being
+    # capped is max_fertile_nodes.)
+    self.max_fertile_nodes = min(self.max_fertile_nodes,
+                                 int(math.ceil(self.max_nodes / 2.0)))
+
+    # split_after_samples and valid_leaf_threshold should be about the same.
+    # Therefore, if either is set, use it to set the other.  Otherwise, fall
+    # back on FLAGS.samples_to_decide.
+    samples_to_decide = (
+        getattr(self, 'split_after_samples',
+                getattr(self, 'valid_leaf_threshold', FLAGS.samples_to_decide)))
+    self.split_after_samples = getattr(self, 'split_after_samples',
+                                       samples_to_decide)
+    self.valid_leaf_threshold = getattr(self, 'valid_leaf_threshold',
+                                        samples_to_decide)
+
+    # We have num_splits_to_consider slots to fill, and we want to spend
+    # approximately split_after_samples samples initializing them.
+    num_split_initializations_per_input = max(1, int(math.floor(
+        self.num_splits_to_consider / self.split_after_samples)))
+    self.split_initializations_per_input = getattr(
+        self, 'split_initializations_per_input',
+        num_split_initializations_per_input)
+
+    # If base_random_seed is 0, the current time will be used to seed the
+    # random number generators for each tree.  If non-zero, the i-th tree
+    # will be seeded with base_random_seed + i.
+    self.base_random_seed = getattr(self, 'base_random_seed', 0)
+
+    return self
+
+
+# A simple container to hold the training variables for a single tree.
+class TreeTrainingVariables(object):
+  """Creates and holds the tf.Variables used to train a single tree."""
+
+  def __init__(self, params):
+    """Builds every variable, sized from the given hyperparameters.
+
+    Args:
+      params: Object providing max_nodes, num_splits_to_consider,
+        max_fertile_nodes and num_classes.
+    """
+    # Row 0 starts as [-1, -1]; every remaining row starts as [-2, -1].
+    tree_init = [[-1, -1]] + [[-2, -1]] * (params.max_nodes - 1)
+    self.tree = tf.Variable(tree_init, name='tree')
+    thresholds_init = [-1.0] * (params.max_nodes)
+    self.tree_thresholds = tf.Variable(thresholds_init, name='tree_thresholds')
+    depths_init = [1] * (params.max_nodes)
+    self.tree_depths = tf.Variable(depths_init, name='tree_depths')
+    self.end_of_tree = tf.Variable([1], name='end_of_tree')
+
+    self.non_fertile_leaves = tf.Variable([0], name='non_fertile_leaves')
+    self.non_fertile_leaf_scores = tf.Variable(
+        [1.0], name='non_fertile_leaf_scores')
+
+    accumulator_map_init = [-1] * params.max_nodes
+    self.node_to_accumulator_map = tf.Variable(
+        accumulator_map_init, name='node_to_accumulator_map')
+
+    # One row of num_splits_to_consider candidates per fertile-node slot.
+    split_features_init = (
+        [[-1] * params.num_splits_to_consider] * params.max_fertile_nodes)
+    self.candidate_split_features = tf.Variable(
+        split_features_init, name='candidate_split_features')
+    split_thresholds_init = (
+        [[0.0] * params.num_splits_to_consider] * params.max_fertile_nodes)
+    self.candidate_split_thresholds = tf.Variable(
+        split_thresholds_init, name='candidate_split_thresholds')
+
+    node_weights_init = [[0.0] * params.num_classes] * params.max_nodes
+    self.node_per_class_weights = tf.Variable(
+        node_weights_init, name='node_per_class_weights')
+    candidate_weights_init = (
+        [[[0.0] * params.num_classes] * params.num_splits_to_consider] *
+        params.max_fertile_nodes)
+    self.candidate_split_per_class_weights = tf.Variable(
+        candidate_weights_init, name='candidate_split_per_class_weights')
+    total_weights_init = (
+        [[-1.0] * params.num_classes] * params.max_fertile_nodes)
+    self.total_split_per_class_weights = tf.Variable(
+        total_weights_init, name='total_split_per_class_weights')
+
+
+class ForestStats(object):
+  """A simple container for stats about a forest."""
+
+  def __init__(self, tree_stats, params):
+    self.tree_stats, self.params = tree_stats, params
+
+  def get_average(self, thing):
+    """Returns the mean of attribute `thing` over the first num_trees stats."""
+    num_trees = self.params.num_trees
+    total = sum(
+        (getattr(self.tree_stats[t], thing) for t in range(num_trees)), 0.0)
+    return total / num_trees
+
+
+class TreeStats(object):
+  """Holds the node count and leaf count reported for a single tree."""
+
+  def __init__(self, num_nodes, num_leaves):
+    self.num_nodes, self.num_leaves = num_nodes, num_leaves
+
+
+def get_tree_stats(variables, unused_params, session):
+  """Evaluates node and leaf counts for one tree.
+
+  Args:
+    variables: Object exposing the tree's `end_of_tree` and `tree` variables.
+    unused_params: Ignored.
+    session: The session in which the variables are evaluated.
+
+  Returns:
+    A TreeStats whose num_nodes is the evaluated end_of_tree minus 1, and
+    whose num_leaves is the number of rows of `tree` whose first column
+    equals LEAF_NODE.
+  """
+  num_nodes = variables.end_of_tree.eval(session=session) - 1
+  first_column = tf.squeeze(tf.slice(variables.tree, [0, 0], [-1, 1]))
+  leaf_positions = tf.where(tf.equal(first_column, LEAF_NODE))
+  num_leaves = leaf_positions.eval(session=session).shape[0]
+  return TreeStats(num_nodes, num_leaves)
+
+
+def get_forest_stats(variables, params, session):
+  """Collects a TreeStats for each of the first params.num_trees trees.
+
+  Args:
+    variables: Indexable container; variables[i] holds tree i's variables.
+    params: Object providing num_trees; also passed on to get_tree_stats.
+    session: The session in which the variables are evaluated.
+
+  Returns:
+    A ForestStats wrapping the per-tree stats and `params`.
+  """
+  per_tree = [get_tree_stats(variables[t], params, session)
+              for t in range(params.num_trees)]
+  return ForestStats(per_tree, params)
+
+
+class ForestTrainingVariables(object):
+  """A container for a forest's training data, consisting of multiple trees.
+
+  Instantiates a TreeTrainingVariables object for each tree.  __getitem__ and
+  __setitem__ index into that per-tree list, so usage looks like this:
+
+    forest_variables = ForestTrainingVariables(params)
+
+    ... forest_variables[i].tree ...
+  """
+
+  def __init__(self, params):
+    per_tree = []
+    for _ in range(params.num_trees):
+      per_tree.append(TreeTrainingVariables(params))
+    self.variables = per_tree
+
+  def __setitem__(self, t, val):
+    self.variables[t] = val
+
+  def __getitem__(self, t):
+    return self.variables[t]
+
+
+class RandomForestGraphs(object):
+  """Builds TF graphs for random forest training and inference."""
+
+  def __init__(self, params):
+    """Creates the per-tree variables and one graph builder per tree.
+
+    Args:
+      params: Hyperparameter object.  num_trees is read here, and the object
+        is handed on to the variables container and every per-tree builder.
+    """
+    self.params = params
+    self.variables = ForestTrainingVariables(self.params)
+    self.trees = [RandomTreeGraphs(self.variables[i], self.params,
+                                   training_ops.Load(), inference_ops.Load())
+                  for i in range(self.params.num_trees)]
+
+  def training_graph(self, input_data, input_labels):
+    """Constructs a TF graph for training a random forest.
+
+    Args:
+      input_data: A tensor or placeholder for input data.
+      input_labels: A tensor or placeholder for labels associated with
+        input_data.
+
+    Returns:
+      A single op grouping the training graphs of every tree.
+    """
+    tree_graphs = []
+    for i in range(self.params.num_trees):
+      tf.logging.info('Constructing tree %d', i)
+      # A base seed of 0 is passed through unchanged to every tree; any other
+      # base seed is offset by the tree index.
+      seed = self.params.base_random_seed
+      if seed != 0:
+        seed += i
+      tree_graphs.append(self.trees[i].training_graph(
+          input_data, input_labels, seed))
+    return tf.group(*tree_graphs)
+
+  def inference_graph(self, input_data):
+    """Constructs a TF graph for evaluating a random forest.
+
+    Args:
+      input_data: A tensor or placeholder for input data.
+
+    Returns:
+      A tensor holding the per-tree inference outputs summed over trees and
+      divided by num_trees.
+    """
+    probabilities = [self.trees[i].inference_graph(input_data)
+                     for i in range(self.params.num_trees)]
+    all_predict = tf.pack(probabilities)
+    return tf.reduce_sum(all_predict, 0) / self.params.num_trees
+
+  def average_impurity(self):
+    """Constructs a TF graph for evaluating the leaf impurity of a forest.
+
+    Returns:
+      A tensor holding the mean of the per-tree average impurities.
+    """
+    # RandomTreeGraphs.average_impurity() takes no arguments: each tree
+    # builder already holds its own variables.
+    impurities = [self.trees[i].average_impurity()
+                  for i in range(self.params.num_trees)]
+    return tf.reduce_mean(tf.pack(impurities))
+
+
+class RandomTreeGraphs(object):
+ """Builds TF graphs for random tree training and inference."""
+
+ def __init__(self, variables, params, t_ops, i_ops):
+   """Stores the tree's variables, hyperparameters and loaded op libraries.
+
+   Args:
+     variables: The variables object for this tree.
+     params: Hyperparameter object shared by the forest.
+     t_ops: The loaded training ops module.
+     i_ops: The loaded inference ops module.
+   """
+   self.training_ops, self.inference_ops = t_ops, i_ops
+   self.variables, self.params = variables, params
+
+ def _gini(self, class_counts):
+   """Calculate the Gini impurity of add-one smoothed class counts.
+
+   If s(i) = 1 + c(i) denotes the smoothed i-th class count and
+   s = sum_i s(i) then
+     score = 1 - sum_i ( s(i) / s )^2
+
+   Args:
+     class_counts: A 2-D tensor of per-class counts, from either
+       candidate_split_per_class_weights or total_split_per_class_weights.
+
+   Returns:
+     A 1-D tensor of the Gini impurities for each row in the input.
+   """
+   counts = 1.0 + class_counts
+   row_totals = tf.reduce_sum(counts, 1)
+   row_square_totals = tf.reduce_sum(tf.square(counts), 1)
+   purity = row_square_totals / (row_totals * row_totals)
+   return 1.0 - purity
+
+ def _weighted_gini(self, class_counts):
+   """Our split score is the Gini impurity times the number of examples.
+
+   Counts are add-one smoothed first.  If s(i) = 1 + c(i) denotes the
+   smoothed i-th class count and s = sum_i s(i) then
+     score = s * (1 - sum_i ( s(i) / s )^2 )
+           = s - sum_i s(i)^2 / s
+
+   Args:
+     class_counts: A 2-D tensor of per-class counts, from either
+       candidate_split_per_class_weights or total_split_per_class_weights.
+
+   Returns:
+     A 1-D tensor of the count-weighted Gini impurities for each row in the
+     input.
+   """
+   counts = 1.0 + class_counts
+   row_totals = tf.reduce_sum(counts, 1)
+   row_square_totals = tf.reduce_sum(tf.square(counts), 1)
+   correction = row_square_totals / row_totals
+   return row_totals - correction
+
+ def training_graph(self, input_data, input_labels, random_seed):
+ """Constructs a TF graph for training a random tree.
+
+ Args:
+ input_data: A tensor or placeholder for input data.
+ input_labels: A tensor or placeholder for labels associated with
+ input_data.
+ random_seed: The random number generator seed to use for this tree. 0
+ means use the current time as the seed.
+
+ Returns:
+ The last op in the random tree training graph: a group over every
+ variable update collected in `updates` below.
+ """
+ # Count extremely random stats.
+ (pcw_node_delta, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
+ pcw_totals_delta, input_leaves) = (
+ self.training_ops.count_extremely_random_stats(
+ input_data, input_labels, self.variables.tree,
+ self.variables.tree_thresholds,
+ self.variables.node_to_accumulator_map,
+ self.variables.candidate_split_features,
+ self.variables.candidate_split_thresholds,
+ num_classes=self.params.num_classes))
+ # Apply the returned deltas to the three per-class weight variables.
+ node_update_op = tf.assign_add(self.variables.node_per_class_weights,
+ pcw_node_delta)
+ candidate_update_op = self.training_ops.scatter_add_ndim(
+ self.variables.candidate_split_per_class_weights,
+ pcw_splits_indices, pcw_splits_delta)
+
+ totals_update_op = self.training_ops.scatter_add_ndim(
+ self.variables.total_split_per_class_weights, pcw_totals_indices,
+ pcw_totals_delta)
+
+ # Sample inputs.
+ update_indices, feature_updates, threshold_updates = (
+ self.training_ops.sample_inputs(
+ input_data, self.variables.node_to_accumulator_map,
+ input_leaves, self.variables.candidate_split_features,
+ self.variables.candidate_split_thresholds,
+ split_initializations_per_input=(
+ self.params.split_initializations_per_input),
+ split_sampling_random_seed=random_seed))
+ # Write the sampled features and thresholds into the candidate variables.
+ update_features_op = tf.scatter_update(
+ self.variables.candidate_split_features, update_indices,
+ feature_updates)
+ update_thresholds_op = tf.scatter_update(
+ self.variables.candidate_split_thresholds, update_indices,
+ threshold_updates)
+
+ # Calculate finished nodes.
+ with tf.control_dependencies([totals_update_op]):
+ # Leaves are the rows of `tree` whose first column equals LEAF_NODE.
+ children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]),
+ squeeze_dims=[1])
+ is_leaf = tf.equal(LEAF_NODE, children)
+ leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1]))
+ finished = self.training_ops.finished_nodes(
+ leaves, self.variables.node_to_accumulator_map,
+ self.variables.total_split_per_class_weights,
+ num_split_after_samples=self.params.split_after_samples)
+
+ # Update leaf scores.
+ # TODO(gilberth): Optimize this. It currently calculates counts for
+ # every non-fertile leaf.
+ with tf.control_dependencies([node_update_op]):
+ # f1 leaves the stored scores untouched; f2 recomputes them from the
+ # current per-class weights of the non-fertile leaves.
+ def f1():
+ return self.variables.non_fertile_leaf_scores
+ def f2():
+ counts = tf.gather(self.variables.node_per_class_weights,
+ self.variables.non_fertile_leaves)
+ new_scores = self._weighted_gini(counts)
+ return tf.assign(self.variables.non_fertile_leaf_scores, new_scores)
+
+ # Because we can't have tf.Variables of size 0, we have to put in a
+ # garbage value of -1 in there. Here we check for that so we don't
+ # try to index into node_per_class_weights in a tf.gather with a negative
+ # number.
+ update_nonfertile_leaves_scores_op = tf.cond(tf.less(
+ self.variables.non_fertile_leaves[0], 0), f1, f2)
+
+ # Calculate best splits.
+ with tf.control_dependencies([candidate_update_op, totals_update_op]):
+ split_indices = self.training_ops.best_splits(
+ finished, self.variables.node_to_accumulator_map,
+ self.variables.candidate_split_per_class_weights,
+ self.variables.total_split_per_class_weights)
+
+ # Grow tree.
+ with tf.control_dependencies([update_features_op, update_thresholds_op]):
+ (tree_update_indices, tree_children_updates,
+ tree_threshold_updates, tree_depth_updates, new_eot) = (
+ self.training_ops.grow_tree(
+ self.variables.end_of_tree, self.variables.tree_depths,
+ self.variables.node_to_accumulator_map, finished, split_indices,
+ self.variables.candidate_split_features,
+ self.variables.candidate_split_thresholds))
+ # The same indices drive the children, threshold and depth updates.
+ tree_update_op = tf.scatter_update(
+ self.variables.tree, tree_update_indices, tree_children_updates)
+ threhsolds_update_op = tf.scatter_update(
+ self.variables.tree_thresholds, tree_update_indices,
+ tree_threshold_updates)
+ depth_update_op = tf.scatter_update(
+ self.variables.tree_depths, tree_update_indices, tree_depth_updates)
+
+ # Update fertile slots.
+ with tf.control_dependencies([update_nonfertile_leaves_scores_op,
+ depth_update_op]):
+ (node_map_updates, accumulators_cleared, accumulators_allocated,
+ new_nonfertile_leaves, new_nonfertile_leaves_scores) = (
+ self.training_ops.update_fertile_slots(
+ finished, self.variables.non_fertile_leaves,
+ self.variables.non_fertile_leaf_scores,
+ self.variables.end_of_tree, self.variables.tree_depths,
+ self.variables.candidate_split_per_class_weights,
+ self.variables.total_split_per_class_weights,
+ self.variables.node_to_accumulator_map,
+ max_depth=self.params.max_depth))
+
+ # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
+ # used it to calculate new leaves.
+ gated_new_eot, = tf.tuple([new_eot], control_inputs=[new_nonfertile_leaves])
+ eot_update_op = tf.assign(self.variables.end_of_tree, gated_new_eot)
+
+ updates = []
+ updates.append(eot_update_op)
+ updates.append(tree_update_op)
+ updates.append(threhsolds_update_op)
+ # validate_shape=False: the new values are assigned without checking them
+ # against the variables' current shapes.
+ updates.append(tf.assign(
+ self.variables.non_fertile_leaves, new_nonfertile_leaves,
+ validate_shape=False))
+ updates.append(tf.assign(
+ self.variables.non_fertile_leaf_scores,
+ new_nonfertile_leaves_scores, validate_shape=False))
+
+ # node_map_updates is consumed as two rows: row 0 supplies the indices
+ # into node_to_accumulator_map and row 1 supplies the values written there.
+ updates.append(tf.scatter_update(
+ self.variables.node_to_accumulator_map,
+ tf.squeeze(tf.slice(node_map_updates, [0, 0], [1, -1]),
+ squeeze_dims=[0]),
+ tf.squeeze(tf.slice(node_map_updates, [1, 0], [1, -1]),
+ squeeze_dims=[0])))
+
+ cleared_and_allocated_accumulators = tf.concat(
+ 0, [accumulators_cleared, accumulators_allocated])
+ # Calculate values to put into scatter update for candidate counts.
+ # Candidate split counts are always reset back to 0 for both cleared
+ # and allocated accumulators. This means some accumulators might be doubly
+ # reset to 0 if they were released and not allocated, then later allocated.
+ candidate_pcw_values = tf.tile(
+ tf.expand_dims(tf.expand_dims(
+ tf.zeros_like(cleared_and_allocated_accumulators, dtype=tf.float32),
+ 1), 2),
+ [1, self.params.num_splits_to_consider, self.params.num_classes])
+ updates.append(tf.scatter_update(
+ self.variables.candidate_split_per_class_weights,
+ cleared_and_allocated_accumulators, candidate_pcw_values))
+
+ # Calculate values to put into scatter update for total counts.
+ # Cleared accumulators get rows of -1.0; allocated accumulators get rows
+ # of 0.0. The concat order matches cleared_and_allocated_accumulators.
+ total_cleared = tf.tile(
+ tf.expand_dims(
+ tf.neg(tf.ones_like(accumulators_cleared, dtype=tf.float32)), 1),
+ [1, self.params.num_classes])
+ total_reset = tf.tile(
+ tf.expand_dims(
+ tf.zeros_like(accumulators_allocated, dtype=tf.float32), 1),
+ [1, self.params.num_classes])
+ total_pcw_updates = tf.concat(0, [total_cleared, total_reset])
+ updates.append(tf.scatter_update(
+ self.variables.total_split_per_class_weights,
+ cleared_and_allocated_accumulators, total_pcw_updates))
+
+ # Calculate values to put into scatter update for candidate splits.
+ # Every candidate feature of a cleared or allocated accumulator is set to -1.
+ split_features_updates = tf.tile(
+ tf.expand_dims(
+ tf.neg(tf.ones_like(cleared_and_allocated_accumulators)), 1),
+ [1, self.params.num_splits_to_consider])
+ updates.append(tf.scatter_update(
+ self.variables.candidate_split_features,
+ cleared_and_allocated_accumulators, split_features_updates))
+
+ return tf.group(*updates)
+
+ def inference_graph(self, input_data):
+   """Constructs a TF graph for evaluating a random tree.
+
+   Args:
+     input_data: A tensor or placeholder for input data.
+
+   Returns:
+     The last op in the random tree inference graph, i.e. the output of the
+     tree_predictions op for this tree's variables.
+   """
+   tree_vars = self.variables
+   predictions = self.inference_ops.tree_predictions(
+       input_data, tree_vars.tree, tree_vars.tree_thresholds,
+       tree_vars.node_per_class_weights,
+       valid_leaf_threshold=self.params.valid_leaf_threshold)
+   return predictions
+
+ def average_impurity(self):
+   """Constructs a TF graph for evaluating the average leaf impurity of a tree.
+
+   Returns:
+     The last op in the graph: the summed weighted Gini score of all leaves
+     divided by the sum of their add-one smoothed class counts.
+   """
+   # Leaves are the rows of `tree` whose first column equals LEAF_NODE.
+   first_column = tf.slice(self.variables.tree, [0, 0], [-1, 1])
+   children = tf.squeeze(first_column, squeeze_dims=[1])
+   leaf_mask = tf.equal(LEAF_NODE, children)
+   leaf_ids = tf.to_int32(tf.squeeze(tf.where(leaf_mask), squeeze_dims=[1]))
+   leaf_counts = tf.gather(self.variables.node_per_class_weights, leaf_ids)
+   leaf_scores = self._weighted_gini(leaf_counts)
+   return tf.reduce_sum(leaf_scores) / tf.reduce_sum(leaf_counts + 1.0)
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
new file mode 100644
index 0000000..e4846cb
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -0,0 +1,55 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.python.tensor_forest."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python import tensor_forest
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class TensorForestTest(test_util.TensorFlowTestCase):
+  """Checks that the forest training and inference graphs can be built."""
+
+  # Four 2-feature points shared by both construction tests.
+  _INPUT_DATA = [[-1., 0.], [-1., 2.],  # node 1
+                 [1., 0.], [1., -2.]]  # node 2
+
+  def _MakeGraphBuilder(self):
+    """Returns a RandomForestGraphs for a small 4-class, 10-tree forest."""
+    params = tensor_forest.ForestHParams(
+        num_classes=4, num_features=2, num_trees=10, max_nodes=1000).fill()
+    return tensor_forest.RandomForestGraphs(params)
+
+  def testTrainingConstruction(self):
+    input_labels = [0, 1, 2, 3]
+
+    graph = self._MakeGraphBuilder().training_graph(
+        self._INPUT_DATA, input_labels)
+    # assertIsInstance reports the actual type when the check fails.
+    self.assertIsInstance(graph, tf.Operation)
+
+  def testInferenceConstruction(self):
+    graph = self._MakeGraphBuilder().inference_graph(self._INPUT_DATA)
+    self.assertIsInstance(graph, tf.Tensor)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/tools/ci_build/builds/test_installation.sh b/tensorflow/tools/ci_build/builds/test_installation.sh
index b6385ac..2b87ac3 100755
--- a/tensorflow/tools/ci_build/builds/test_installation.sh
+++ b/tensorflow/tools/ci_build/builds/test_installation.sh
@@ -76,7 +76,8 @@
# Test blacklist: GPU-only
PY_TEST_GPU_BLACKLIST="${PY_TEST_GPU_BLACKLIST}:"\
-"tensorflow/python/framework/function_test.py"
+"tensorflow/python/framework/function_test.py:"\
+"tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py"
# Tests that should be run in the exclusive mode (i.e., not parallel with
# other tests)