Add extremely random forest implementation to TF/contrib.
Change: 118888391
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 366af6b..a5ba85f 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -21,6 +21,7 @@
         "//tensorflow/contrib/lookup:lookup_py",
         "//tensorflow/contrib/losses:losses_py",
         "//tensorflow/contrib/skflow",
+        "//tensorflow/contrib/tensor_forest:tensor_forest_py",
         "//tensorflow/contrib/testing:testing_py",
         "//tensorflow/contrib/util:util_py",
     ],
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
new file mode 100644
index 0000000..33c66a6b
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -0,0 +1,194 @@
+# Tensorflow code for training random forests.
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+)
+
+cc_library(
+    name = "tree_utils",
+    srcs = ["core/ops/tree_utils.cc"],
+    hdrs = [
+        "core/ops/tree_utils.h",
+    ],
+    deps = [
+        "//google/protobuf",
+        "//tensorflow/core:framework_headers_lib",
+        "//third_party/eigen3",
+    ],
+)
+
+tf_custom_op_library(
+    name = "python/ops/_inference_ops.so",
+    srcs = [
+        "core/ops/tree_predictions_op.cc",
+    ],
+    deps = [":tree_utils"],
+)
+
+tf_custom_op_library(
+    name = "python/ops/_training_ops.so",
+    srcs = [
+        "core/ops/best_splits_op.cc",
+        "core/ops/count_extremely_random_stats_op.cc",
+        "core/ops/finished_nodes_op.cc",
+        "core/ops/grow_tree_op.cc",
+        "core/ops/sample_inputs_op.cc",
+        "core/ops/scatter_add_ndim_op.cc",
+        "core/ops/update_fertile_slots_op.cc",
+    ],
+    deps = [":tree_utils"],
+)
+
+py_library(
+    name = "ops_lib",
+    srcs = [
+        "__init__.py",
+        "python/ops/inference_ops.py",
+        "python/ops/training_ops.py",
+    ],
+    data = [
+        "python/ops/_inference_ops.so",
+        "python/ops/_training_ops.so",
+    ],
+    srcs_version = "PY2AND3",
+)
+
+py_test(
+    name = "best_splits_op_test",
+    size = "small",
+    srcs = ["python/kernel_tests/best_splits_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+py_test(
+    name = "count_extremely_random_stats_op_test",
+    size = "small",
+    srcs = ["python/kernel_tests/count_extremely_random_stats_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+py_test(
+    name = "grow_tree_op_test",
+    size = "small",
+    srcs = ["python/kernel_tests/grow_tree_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+py_test(
+    name = "finished_nodes_op_test",
+    size = "small",
+    srcs = ["python/kernel_tests/finished_nodes_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+py_test(
+    name = "sample_inputs_op_test",
+    size = "small",
+    srcs = ["python/kernel_tests/sample_inputs_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+py_test(
+    name = "scatter_add_ndim_op_test",
+    size = "small",
+    srcs = ["python/kernel_tests/scatter_add_ndim_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+py_test(
+    name = "tree_predictions_op_test",
+    size = "small",
+    srcs = ["python/kernel_tests/tree_predictions_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+py_test(
+    name = "update_fertile_slots_op_test",
+    size = "small",
+    srcs = ["python/kernel_tests/update_fertile_slots_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+py_library(
+    name = "tensor_forest_py",
+    srcs = ["python/tensor_forest.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_lib",
+    ],
+)
+
+py_test(
+    name = "tensor_forest_test",
+    size = "small",
+    srcs = ["python/tensor_forest_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":tensor_forest_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
diff --git a/tensorflow/contrib/tensor_forest/OWNERS b/tensorflow/contrib/tensor_forest/OWNERS
new file mode 100644
index 0000000..becb4d2
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/OWNERS
@@ -0,0 +1,3 @@
+dsculley
+gilberth
+thomaswc
\ No newline at end of file
diff --git a/tensorflow/contrib/tensor_forest/__init__.py b/tensorflow/contrib/tensor_forest/__init__.py
new file mode 100644
index 0000000..878d9c5
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Imports training and inference custom ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensor_forest.python.ops.inference_ops import *
+from tensorflow.contrib.tensor_forest.python.ops.training_ops import *
diff --git a/tensorflow/contrib/tensor_forest/core/ops/best_splits_op.cc b/tensorflow/contrib/tensor_forest/core/ops/best_splits_op.cc
new file mode 100644
index 0000000..3c454be
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/best_splits_op.cc
@@ -0,0 +1,117 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// BestSplits returns the index of the best candidate for each finished node.
+// This decision is based on the Gini score of the pcw_candidate_split counts,
+// and the right-branch-taken counts inferred from pcw_total_splits.
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+
+namespace tensorflow {
+
+using tensorforest::BestFeature;
+
+
+REGISTER_OP("BestSplits")
+  .Input("finished_nodes: int32")
+  .Input("node_to_accumulator: int32")
+  .Input("pcw_candidate_splits: float")
+  .Input("pcw_total_splits: float")
+  .Output("split_indices: int32")
+  .Doc(R"doc(
+  Returns the index of the best split for each finished node.
+
+  The best split is the split with the lowest weighted Gini impurity,
+  as calculated from the statistics in `pcw_candidate_splits` and
+  `pcw_total_splits`.
+
+  finished_nodes:= A 1-d int32 tensor containing the indices of finished nodes.
+  node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
+    fertile node i, or -1 if node i isn't fertile.
+  pcw_candidate_splits: `pcw_candidate_splits[a][s][c]` records how many
+    training examples have class c and have ended up in the fertile node
+    associated with accumulator slot a and then taken the *left* branch of
+    candidate split s.
+  pcw_total_splits: `pcw_total_splits[a][c]` records how many training examples
+    have class c and have ended up in the fertile node associated with
+    accumulator slot a.  Between that and `pcw_candidate_splits`, the number of
+    examples taking the right branch of a split can be reconstructed.
+  split_indices: `split_indices[i]` contains the index of the split to use for
+    `finished_nodes[i]`.
+)doc");
+
+
+class BestSplits : public OpKernel {
+ public:
+  explicit BestSplits(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& finished = context->input(0);
+    const Tensor& node_to_accumulator = context->input(1);
+    const Tensor& pcw_candidate_splits = context->input(2);
+    const Tensor& pcw_total_splits = context->input(3);
+
+    OP_REQUIRES(context, finished.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "finished should be one-dimensional"));
+    OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "node_to_accumulator should be one-dimensional"));
+
+    OP_REQUIRES(context, pcw_candidate_splits.shape().dims() == 3,
+                errors::InvalidArgument(
+                    "pcw_candidate_splits should be three-dimensional"));
+    OP_REQUIRES(context, pcw_total_splits.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "pcw_total_splits should be two-dimensional"));
+
+    OP_REQUIRES(
+        context,
+        pcw_candidate_splits.shape().dim_size(0) ==
+        pcw_total_splits.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of accumulators should be the same in pcw_candidate_splits "
+            "and pcw_total_splits."));
+
+    Tensor* output_splits = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, finished.shape(),
+                                            &output_splits));
+    auto best_splits = output_splits->unaligned_flat<int32>();
+
+    const auto finished_vec = finished.unaligned_flat<int32>();
+    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+
+    const int32 num_finished = finished.shape().dim_size(0);
+
+    for (int i = 0; i < num_finished; i++) {
+      const int32 node = finished_vec(i);
+      const int32 accumulator = node_map(node);
+      if (accumulator < 0) {
+        LOG(ERROR) << "Something has gone wrong, we got a finished node that "
+                   << "doesn't have an accumulator allocated to it.";
+        continue;
+      }
+      best_splits(i) = BestFeature(pcw_total_splits,
+                                   pcw_candidate_splits, accumulator);
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("BestSplits").Device(DEVICE_CPU), BestSplits);
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
new file mode 100644
index 0000000..ab5ac9c
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
@@ -0,0 +1,331 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// CountExtremelyRandomStats outputs count-deltas that should be added to
+// the node pcws, candidate split pcws, and total split pcws.  It also outputs
+// the leaves that each input arrived to for use in SampleInputs.  This is the
+// only op that involves tree traversal, and is constructed so that it can
+// be run in parallel on separate batches of data.
+#include <unordered_map>
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+using std::get;
+using std::make_tuple;
+using std::pair;
+using std::tuple;
+
+using tensorforest::CHILDREN_INDEX;
+using tensorforest::FEATURE_INDEX;
+using tensorforest::LEAF_NODE;
+using tensorforest::FREE_NODE;
+
+using tensorforest::DecideNode;
+using tensorforest::Initialize;
+using tensorforest::IsAllInitialized;
+
+REGISTER_OP("CountExtremelyRandomStats")
+  .Attr("num_classes: int32")
+  .Input("input_data: float")
+
+  .Input("input_labels: int32")
+
+  .Input("tree: int32")
+  .Input("tree_thresholds: float")
+
+  .Input("node_to_accumulator: int32")
+
+  .Input("candidate_split_features: int32")
+  .Input("candidate_split_thresholds: float")
+
+  .Output("pcw_node_delta: float")
+  .Output("pcw_splits_indices: int32")
+  .Output("pcw_candidate_splits_delta: float")
+  .Output("pcw_totals_indices: int32")
+  .Output("pcw_total_splits_delta: float")
+
+  .Output("leaves: int32")
+  .Doc(R"doc(
+   Calculates incremental statistics for a batch of training data.
+
+   Each training example in `input_data` is sent through the decision tree
+   represented by `tree` and `tree_thresholds`.  `pcw_node_delta[i]` is
+   incremented for every node i that it passes through, and the leaf it ends up
+   in is recorded in `leaves[i]`.  Then, if the leaf is fertile and
+   initialized, the statistics for its corresponding accumulator slot
+   are updated in in `pcw_candidate_splits_delta` and `pcw_total_splits_delta`.
+
+   The attr `num_classes` is needed to appropriately size the outputs.
+
+   input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
+     gives the j-th feature of the i-th input.
+   input_labels: The training batch's labels; `input_labels[i]` is the class
+     of the i-th input.
+   tree:= A 2-d int32 tensor.  `tree[0][i]` gives the index of the left child
+     of the i-th node, `tree[0][i] + 1` gives the index of the right child of
+     the i-th node, and `tree[1][i]` gives the index of the feature used to
+     split the i-th node.
+   tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
+     node.
+   node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]`
+     is it's accumulator slot.  Otherwise, `node_to_accumulator[i]` is -1.
+   candidate_split_features: `candidate_split_features[a][s]` is the
+     index of the feature being considered by split s of accumulator slot a.
+   candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the
+     threshold value being considered by split s of accumulator slot a.
+   pcw_node_delta: `pcw_node_delta[i][c]` is the number of training examples
+     in this training batch with class c that passed through node i.
+   pcw_splits_indices:= A 2-d tensor of shape (?, 3).
+     `pcw_splits_indices[i]` gives the coordinates of an entry in
+     candidate_split_per_class_weights that needs to be updated.
+     This is meant to be passed with `pcw_candidate_splits_delta` to a
+     scatter_add for candidate_split_per_class_weights:
+       training_ops.scatter_add_ndim(candidate_split_per_class_weights,
+           pcw_splits_indices, pcw_candidate_splits_delta)
+   pcw_candidate_splits_delta: `pcw_candidate_splits_delta[i]` is the
+     number of training examples in this training batch that correspond to
+     the i-th entry in `pcw_splits_indices` which took the *left* branch of
+     candidate split.
+   pcw_totals_indices: 'pcw_totals_indices` contains the indices (accumulator,
+     class) into total_per_class_weights to update with pcw_total_splits_delta.
+   pcw_total_splits_delta: `pcw_total_splits_delta[i]` is the number of
+     training examples in this batch that ended up in the fertile
+     node with accumulator and class indicated by `pcw_totals_indices[i]`.
+   leaves: `leaves[i]` is the leaf that input i ended up in.
+)doc");
+
+
+class CountExtremelyRandomStats : public OpKernel {
+ public:
+  explicit CountExtremelyRandomStats(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "num_classes", &num_classes_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_data = context->input(0);
+    const Tensor& input_labels = context->input(1);
+    const Tensor& tree_tensor = context->input(2);
+    const Tensor& tree_thresholds = context->input(3);
+    const Tensor& node_to_accumulator = context->input(4);
+    const Tensor& candidate_split_features = context->input(5);
+    const Tensor& candidate_split_thresholds = context->input(6);
+
+    // Check inputs.
+    OP_REQUIRES(context, input_data.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "input_data should be two-dimensional"));
+    OP_REQUIRES(context, input_labels.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "input_labels should be one-dimensional"));
+
+    OP_REQUIRES(context, tree_tensor.shape().dims() == 2,
+            errors::InvalidArgument(
+                "tree should be two-dimensional"));
+    OP_REQUIRES(context, tree_thresholds.shape().dims() == 1,
+            errors::InvalidArgument(
+                "tree_thresholds should be one-dimensional"));
+    OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+            errors::InvalidArgument(
+                "node_to_accumulator should be one-dimensional"));
+    OP_REQUIRES(context, candidate_split_features.shape().dims() == 2,
+            errors::InvalidArgument(
+                "candidate_split_features should be two-dimensional"));
+    OP_REQUIRES(context, candidate_split_thresholds.shape().dims() == 2,
+            errors::InvalidArgument(
+                "candidate_split_thresholds should be two-dimensional"));
+
+    OP_REQUIRES(
+        context,
+        input_data.shape().dim_size(0) == input_labels.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of inputs should be the same in "
+            "input_data and input_labels."));
+    OP_REQUIRES(
+        context,
+        tree_tensor.shape().dim_size(0) ==
+        tree_thresholds.shape().dim_size(0) &&
+        tree_tensor.shape().dim_size(0) ==
+        node_to_accumulator.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of nodes should be the same in "
+            "tree, tree_thresholds, and node_to_accumulator"));
+    OP_REQUIRES(
+        context,
+        candidate_split_features.shape() == candidate_split_thresholds.shape(),
+        errors::InvalidArgument(
+            "candidate_split_features and candidate_split_thresholds should be "
+            "the same shape."));
+
+    const int32 num_splits = candidate_split_features.shape().dim_size(1);
+
+    // node pcw delta
+    Tensor* output_node_pcw_delta = nullptr;
+    TensorShape node_pcw_shape;
+    node_pcw_shape.AddDim(tree_tensor.shape().dim_size(0));
+    node_pcw_shape.AddDim(num_classes_);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, node_pcw_shape,
+                                            &output_node_pcw_delta));
+    Initialize<float>(*output_node_pcw_delta, 0);
+    auto out_node = output_node_pcw_delta->tensor<float, 2>();
+
+    // leaves
+    Tensor* output_leaves = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(5, input_labels.shape(),
+                                            &output_leaves));
+    auto out_leaves = output_leaves->unaligned_flat<int32>();
+
+    const auto tree = tree_tensor.tensor<int32, 2>();
+    const auto thresholds = tree_thresholds.unaligned_flat<float>();
+    const auto labels = input_labels.unaligned_flat<int32>();
+    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+    const auto split_features = candidate_split_features.tensor<int32, 2>();
+    const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
+
+    const int32 num_data = input_data.shape().dim_size(0);
+
+    // <accumulator, class> -> count delta
+    std::unordered_map<pair<int32, int32>, int32, PairIntHash> total_delta;
+    // <accumulator, split, class> -> count delta
+    std::unordered_map<tuple<int32, int32, int32>,
+        int32, TupleIntHash> split_delta;
+    for (int i = 0; i < num_data; i++) {
+      const Tensor point = input_data.Slice(i, i+1);
+      int node_index = 0;
+      while (true) {
+        const int32 label = labels(i);
+        ++out_node(node_index, label);
+        int32 left_child = tree(node_index, CHILDREN_INDEX);
+        if (left_child == LEAF_NODE) {
+          out_leaves(i) = node_index;
+          const int32 accumulator = node_map(node_index);
+          // If the leaf is not fertile or is not yet initialized, we don't
+          // count it in the candidate/total split per-class-weights because
+          // it won't have any candidate splits yet.
+          if (accumulator >= 0 &&
+              IsAllInitialized(
+                  candidate_split_features.Slice(accumulator,
+                                                 accumulator + 1))) {
+            ++total_delta[std::make_pair(accumulator, label)];
+            for (int split = 0; split < num_splits; split++) {
+              if (!DecideNode(point, split_features(accumulator, split),
+                              split_thresholds(accumulator, split))) {
+                ++split_delta[make_tuple(accumulator, split, label)];
+              }
+            }
+          }
+          break;
+        } else if (left_child == FREE_NODE) {
+          LOG(ERROR) << "Reached a free node, not good.";
+          out_leaves(i) = FREE_NODE;
+          break;
+        }
+        node_index = left_child +
+            DecideNode(point, tree(node_index, FEATURE_INDEX),
+                       thresholds(node_index));
+      }
+    }
+
+     // candidate splits pcw indices
+    Tensor* output_candidate_pcw_indices = nullptr;
+    TensorShape candidate_pcw_shape;
+    candidate_pcw_shape.AddDim(split_delta.size());
+    candidate_pcw_shape.AddDim(3);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, candidate_pcw_shape,
+                                            &output_candidate_pcw_indices));
+    auto out_candidate_indices =
+        output_candidate_pcw_indices->tensor<int32, 2>();
+
+    // candidate splits pcw delta
+    Tensor* output_candidate_pcw_delta = nullptr;
+    TensorShape candidate_pcw_delta_shape;
+    candidate_pcw_delta_shape.AddDim(split_delta.size());
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, candidate_pcw_delta_shape,
+                                            &output_candidate_pcw_delta));
+    auto out_candidate = output_candidate_pcw_delta->unaligned_flat<float>();
+
+    // total splits indices
+    Tensor* output_total_pcw_indices = nullptr;
+    TensorShape total_pcw_shape;
+    total_pcw_shape.AddDim(total_delta.size());
+    total_pcw_shape.AddDim(2);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(3, total_pcw_shape,
+                                            &output_total_pcw_indices));
+    auto out_total_indices = output_total_pcw_indices->tensor<int32, 2>();
+
+    // total splits delta
+    Tensor* output_total_pcw_delta = nullptr;
+    TensorShape total_pcw_delta_shape;
+    total_pcw_delta_shape.AddDim(total_delta.size());
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(4, total_pcw_delta_shape,
+                                            &output_total_pcw_delta));
+    auto out_total = output_total_pcw_delta->unaligned_flat<float>();
+
+    // Copy total deltas to output.
+    int32 output_slot = 0;
+    for (const auto& updates : total_delta) {
+      out_total_indices(output_slot, 0) = updates.first.first;
+      out_total_indices(output_slot, 1) = updates.first.second;
+      out_total(output_slot) = updates.second;
+      ++output_slot;
+    }
+
+    // Copy split deltas to output.
+    output_slot = 0;
+    for (const auto& updates : split_delta) {
+      out_candidate_indices(output_slot, 0) = get<0>(updates.first);
+      out_candidate_indices(output_slot, 1) = get<1>(updates.first);
+      out_candidate_indices(output_slot, 2) = get<2>(updates.first);
+      out_candidate(output_slot) = updates.second;
+      ++output_slot;
+    }
+  }
+
+ private:
+  struct PairIntHash {
+   public:
+    std::size_t operator()(const std::pair<int, int>& x) const {
+      return std::hash<int>()(x.first) ^ std::hash<int>()(x.second);
+    }
+  };
+
+  struct TupleIntHash {
+   public:
+    std::size_t operator()(const std::tuple<int32, int32, int32>& x) const {
+      return std::hash<int32>()(get<0>(x)) ^ std::hash<int32>()(get<1>(x)) ^
+          std::hash<int32>()(get<2>(x));
+    }
+  };
+
+  int32 num_classes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("CountExtremelyRandomStats").Device(DEVICE_CPU),
+                        CountExtremelyRandomStats);
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
new file mode 100644
index 0000000..804ed44
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
@@ -0,0 +1,112 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// FinishedNodes returns a 1-D tensor listing the nodes that are finished
+// accumulating.
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+using tensorforest::Sum;
+
+REGISTER_OP("FinishedNodes")
+  .Attr("num_split_after_samples: int32")
+  .Input("leaves: int32")
+  .Input("node_to_accumulator: int32")
+  .Input("pcw_total_splits: float")
+
+  .Output("finished: int32")
+  .Doc(R"doc(
+  Determines which of the given leaf nodes are done accumulating.
+
+  leaves:= A 1-d int32 tensor.  Lists the nodes that are currently leaves.
+  node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]`
+   is it's accumulator slot.  Otherwise, `node_to_accumulator[i]` is -1.
+  pcw_total_splits: `pcw_total_splits[a][c]` records how many training examples
+   have class c and have ended up in the fertile node associated with
+   accumulator slot a.  Between that and `pcw_candidate_splits`, the number of
+   examples taking the right branch of a split can be reconstructed.
+  finished:= A 1-d int32 tensor. Contains the nodes that have total split
+   counts greater or equal to the num_split_after_samples attribute.
+)doc");
+
+
+class FinishedNodes : public OpKernel {
+ public:
+  explicit FinishedNodes(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "num_split_after_samples", &num_split_after_samples_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& leaf_tensor = context->input(0);
+    const Tensor& node_to_accumulator = context->input(1);
+    const Tensor& pcw_total_splits = context->input(2);
+
+    OP_REQUIRES(context, leaf_tensor.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "leaf_tensor should be one-dimensional"));
+    OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "node_to_accumulator should be one-dimensional"));
+    OP_REQUIRES(context, pcw_total_splits.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "pcw_total_splits should be two-dimensional"));
+
+    const auto leaves = leaf_tensor.unaligned_flat<int32>();
+    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+
+    const int32 num_leaves = leaf_tensor.shape().dim_size(0);
+
+    std::vector<int32> finished;
+    for (int i = 0; i < num_leaves; i++) {
+      const int32 leaf = leaves(i);
+      const int32 accumulator = node_map(leaf);
+      if (accumulator < 0) {
+        continue;
+      }
+
+      if (Sum<float>(pcw_total_splits.Slice(accumulator, accumulator + 1)) >=
+          num_split_after_samples_) {
+        finished.push_back(leaf);
+      }
+    }
+
+    // Copy to output.
+    Tensor* output_finished = nullptr;
+    TensorShape finished_shape;
+    finished_shape.AddDim(finished.size());
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, finished_shape,
+                                            &output_finished));
+    auto out_finished = output_finished->unaligned_flat<int32>();
+
+    for (int32 i = 0; i < finished.size(); i++) {
+      out_finished(i) = finished[i];
+    }
+  }
+
+ private:
+  int32 num_split_after_samples_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("FinishedNodes").Device(DEVICE_CPU),
+                        FinishedNodes);
+
+}  // namespace tensorflow
+
diff --git a/tensorflow/contrib/tensor_forest/core/ops/grow_tree_op.cc b/tensorflow/contrib/tensor_forest/core/ops/grow_tree_op.cc
new file mode 100644
index 0000000..2fde74b
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/grow_tree_op.cc
@@ -0,0 +1,259 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// GrowTree adds children to the tree for finished nodes by using the
+// end_of_tree tensor as an indicator for where free nodes are in the
+// pre-allocated tree tensor.
+// For example if the tree is:
+//    1, -1, -1, -2, -2, -2, ...
+// Then end_of_tree should be 3 (the first -2, or "free" slot in the tensor).
+// If node 1 is now finished, the tree tensor after this op would be:
+//    1, 3, -1, -1, -1, -2, ...
+// and end_of_tree would be 5.
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+
+
+namespace tensorflow {
+
+using tensorforest::CHILDREN_INDEX;
+using tensorforest::FEATURE_INDEX;
+
+using tensorforest::LEAF_NODE;
+
+
+REGISTER_OP("GrowTree")
+  .Input("end_of_tree: int32")
+  .Input("tree_depths: int32")
+  .Input("node_to_accumulator: int32")
+  .Input("finished_nodes: int32")
+  .Input("best_splits: int32")
+  .Input("candidate_split_features: int32")
+  .Input("candidate_split_thresholds: float")
+  .Output("nodes_to_update: int32")
+  .Output("tree_updates: int32")
+  .Output("threshold_updates: float")
+  .Output("depth_updates: int32")
+  .Output("new_end_of_tree: int32")
+  .Doc(R"doc(
+  Output the tree changes needed to resolve fertile nodes.
+
+  Previous Ops have already decided which fertile nodes want to stop being
+  fertile and what their best candidate split should be and have passed that
+  information to this Op in `finished_nodes` and `best_splits`.  This Op
+  merely checks that there is still space in tree to add new nodes, and if
+  so, writes out the sparse updates needed for the fertile nodes to be
+  resolved to the tree, threshold and depth tensors.
+
+  end_of_tree: `end_of_tree[0]` is the number of allocated nodes, or
+    equivalently the index of the first free node in the tree tensor.
+  tree_depths: `tree_depths[i]` is the depth in the tree of node i.
+  node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
+    fertile node i, or -1 if node i isn't fertile.
+  finished_nodes:= A 1-d int32 tensor containing the indices of finished nodes.
+  best_splits: `best_splits[i]` is the index of the best split for
+    `finished_nodes[i]`.
+  candidate_split_features: `candidate_split_features[a][s]` is the feature
+    being considered for split s of the fertile node associated with
+    accumulator slot a.
+  candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the
+    threshold value being considered for split s of the fertile node associated
+    with accumulator slot a.
+  nodes_to_update:= A 1-d int32 tensor containing the node indices that need
+    updating.
+  tree_updates: The updates to apply to the 2-d tree tensor.  Intended to be
+    used with `tf.scatter_update(tree, nodes_to_update, tree_updates)`.
+  threshold_updates: The updates to apply to the 1-d thresholds tensor.
+    Intended to be used with
+    `tf.scatter_update(thresholds, nodes_to_update, threshold_updates)`.
+  depth_updates: The updates to apply to the 1-d depths tensor.  Intended to
+    be used with `tf.scatter_update(depths, nodes_to_update, depth_updates)`.
+  new_end_of_tree: `new_end_of_tree[0]` is the new size of the tree.
+)doc");
+
+class GrowTree : public OpKernel {
+ public:
+  explicit GrowTree(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& end_of_tree = context->input(0);
+    const Tensor& tree_depths = context->input(1);
+    const Tensor& node_to_accumulator = context->input(2);
+    const Tensor& finished = context->input(3);
+    const Tensor& best_splits = context->input(4);
+    const Tensor& candidate_split_features = context->input(5);
+    const Tensor& candidate_split_thresholds = context->input(6);
+
+    OP_REQUIRES(context, end_of_tree.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "end_of_tree should be one-dimensional"));
+    OP_REQUIRES(context, tree_depths.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "tree_depths should be one-dimensional"));
+    OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "node_to_accumulator should be one-dimensional"));
+    OP_REQUIRES(context, finished.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "finished should be one-dimensional"));
+    OP_REQUIRES(context, best_splits.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "best_splits should be one-dimensional"));
+    OP_REQUIRES(context, candidate_split_features.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "candidate_split_features should be two-dimensional"));
+    OP_REQUIRES(context, candidate_split_thresholds.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "candidate_split_thresholds should be two-dimensional"));
+
+    OP_REQUIRES(
+        context,
+        finished.shape().dim_size(0) ==
+        best_splits.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of finished nodes should be the same in finished and "
+            "best_splits."));
+    OP_REQUIRES(
+        context,
+        tree_depths.shape().dim_size(0) ==
+        node_to_accumulator.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of nodes should be the same in tree_depths and "
+            "node_to_accumulator."));
+    OP_REQUIRES(
+        context,
+        candidate_split_features.shape().dim_size(0) ==
+        candidate_split_thresholds.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of accumulators should be the same in "
+            "candidate_split_features and candidate_split_thresholds."));
+    OP_REQUIRES(
+        context,
+        candidate_split_features.shape().dim_size(1) ==
+        candidate_split_thresholds.shape().dim_size(1),
+        errors::InvalidArgument(
+            "Number of splits should be the same in "
+            "candidate_split_features and candidate_split_thresholds."));
+
+    int32 current_end_of_tree = end_of_tree.unaligned_flat<int32>()(0);
+    const auto depths = tree_depths.unaligned_flat<int32>();
+    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+    const auto finished_vec = finished.unaligned_flat<int32>();
+    const auto best_vec = best_splits.unaligned_flat<int32>();
+    const auto split_features = candidate_split_features.tensor<int32, 2>();
+    const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
+
+    const int32 num_finished = finished.shape().dim_size(0);
+    const int32 num_nodes = node_to_accumulator.shape().dim_size(0);
+
+    // Converting a leaf node into an internal node requires space for its
+    // two children.
+    int32 remaining_node_space = (num_nodes - current_end_of_tree) / 2;
+    int32 nodes_we_can_allocate = std::min(num_finished, remaining_node_space);
+    // Each conversion touches three nodes: the transitioning node and its
+    // two new children.
+    int32 num_updates = 3 * nodes_we_can_allocate;
+
+    Tensor* nodes_to_update_tensor = nullptr;
+    TensorShape nodes_to_update_shape;
+    nodes_to_update_shape.AddDim(num_updates);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, nodes_to_update_shape,
+                                            &nodes_to_update_tensor));
+    auto nodes_to_update_flat = nodes_to_update_tensor->tensor<int32, 1>();
+
+    Tensor* tree_updates_tensor = nullptr;
+    TensorShape tree_updates_shape;
+    tree_updates_shape.AddDim(num_updates);
+    tree_updates_shape.AddDim(2);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, tree_updates_shape,
+                                            &tree_updates_tensor));
+    auto tree_updates_flat = tree_updates_tensor->tensor<int32, 2>();
+
+    Tensor* threshold_updates_tensor = nullptr;
+    TensorShape threshold_updates_shape;
+    threshold_updates_shape.AddDim(num_updates);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, threshold_updates_shape,
+                                            &threshold_updates_tensor));
+    auto threshold_updates_flat = threshold_updates_tensor->tensor<float, 1>();
+
+    Tensor* depth_updates_tensor = nullptr;
+    TensorShape depth_updates_shape;
+    depth_updates_shape.AddDim(num_updates);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(3, depth_updates_shape,
+                                            &depth_updates_tensor));
+    auto depth_updates_flat = depth_updates_tensor->tensor<int32, 1>();
+
+    int output_slot = 0;
+    for (int i = 0; i < nodes_we_can_allocate; i++) {
+      const int32 node = finished_vec(i);
+      const int32 best = best_vec(i);
+      const int32 accumulator = node_map(node);
+      if (accumulator < 0) {
+        LOG(ERROR) << "Finished node doesn't have an accumulator.";
+        continue;
+      }
+
+      if (current_end_of_tree >= num_nodes - 1) {
+        LOG(ERROR) << "Could not grow tree any further.";
+        return;
+      }
+      const int32 left = current_end_of_tree;
+      nodes_to_update_flat(output_slot) = node;
+
+      tree_updates_flat(output_slot, CHILDREN_INDEX) = left;
+      tree_updates_flat(output_slot, FEATURE_INDEX) =
+          split_features(accumulator, best);
+      threshold_updates_flat(output_slot) = split_thresholds(accumulator, best);
+      depth_updates_flat(output_slot) = depths(node);
+      output_slot++;
+
+      nodes_to_update_flat(output_slot) = left;
+      tree_updates_flat(output_slot, CHILDREN_INDEX) = LEAF_NODE;
+      tree_updates_flat(output_slot, FEATURE_INDEX) = -1;
+      threshold_updates_flat(output_slot) = 0.0;
+      depth_updates_flat(output_slot) = depths(node) + 1;
+      output_slot++;
+
+      nodes_to_update_flat(output_slot) = left + 1;
+      tree_updates_flat(output_slot, CHILDREN_INDEX) = LEAF_NODE;
+      tree_updates_flat(output_slot, FEATURE_INDEX) = -1;
+      threshold_updates_flat(output_slot) = 0.0;
+      depth_updates_flat(output_slot) = depths(node) + 1;
+      output_slot++;
+
+      current_end_of_tree += 2;
+    }
+
+    Tensor* new_end_of_tree_tensor = nullptr;
+    TensorShape new_end_of_tree_shape;
+    new_end_of_tree_shape.AddDim(1);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(4, new_end_of_tree_shape,
+                                            &new_end_of_tree_tensor));
+    auto new_end_of_tree_flat = new_end_of_tree_tensor->tensor<int32, 1>();
+    new_end_of_tree_flat(0) = current_end_of_tree;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("GrowTree").Device(DEVICE_CPU), GrowTree);
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
new file mode 100644
index 0000000..452a683
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
@@ -0,0 +1,241 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// SampleInputs initializes candidate splits/threshold values randomly
+// from incoming data for not-yet-initialized fertile nodes.
+#include <ctime>
+#include <unordered_map>
+#include <set>
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+using tensorforest::IsAllInitialized;
+
+
+// TODO(gilberth): Reinitialize candidate splits that were finished last round.
+REGISTER_OP("SampleInputs")
+  .Attr("split_initializations_per_input: int32")
+  .Attr("split_sampling_random_seed: int32")
+  .Input("input_data: float")
+  .Input("node_to_accumulator: int32")
+  .Input("leaves: int32")
+  .Input("candidate_split_features: int32")
+  .Input("candidate_split_thresholds: float")
+  .Output("accumulators_to_update: int32")
+  .Output("new_split_feature_rows: int32")
+  .Output("new_split_threshold_rows: float")
+  .Doc(R"doc(
+  Initializes candidate splits for newly fertile nodes.
+
+  In an extremely random forest, we don't consider all possible threshold
+  values for a candidate split feature, but rather only a sampling of them.
+  This Op takes those samples from the training data in `input_data`.  The
+  feature and threshold samples are stored in tensors that are indexed by
+  accumulator slot, so for each input, we must first look up which leaf
+  it ended up in (using `leaves`) and then which accumulator slot if any
+  that leaf maps to (using `node_to_accumulator`).
+
+  The attribute `split_initializations_per_input` controls how many splits
+  a single training example can initialize, and the attribute
+  `split_sampling_random_seed` sets the random number generator's seed
+  (a value of 0 means use the current time as the seed).
+
+  input_data: The features for the current batch of training data.
+    `input_data[i][j]` is the j-th feature of the i-th input.
+  node_to_accumulator: For a fertile node i, node_to_accumulator[i] is the
+    associated accumulator slot.  For non-fertile nodes, it is -1.
+  leaves: `leaves[i]` is the leaf that the i-th input landed in, as
+    calculated by CountExtremelyRandomStats.
+  candidate_split_features: The current features for the candidate splits;
+    `candidate_split_features[a][s]` is the index of the feature being
+    considered by split s in accumulator slot a.
+  candidate_split_thresholds: The current thresholds for the candidate splits;
+    `candidate_split_thresholds[a][s]` is the threshold value being
+    considered by split s in accumulator slot a.
+  accumulators_to_update: A list of the accumulators to change in the
+    candidate_split_features and candidate_split_thresholds tensors.
+  new_split_feature_rows: The new values for the candidate_split_features
+    tensor.  Intended to be used with
+    `tf.scatter_update(candidate_split_features,
+                       accumulators_to_update,
+                       new_split_feature_rows)`
+  new_split_threshold_rows:  The new values for the candidate_split_thresholds
+    tensor.  Intended to be used with
+    `tf.scatter_update(candidate_split_thresholds,
+                       accumulators_to_update,
+                       new_split_feature_thresholds)`
+)doc");
+
+class SampleInputs : public OpKernel {
+ public:
+  explicit SampleInputs(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "split_initializations_per_input", &split_initializations_per_input_));
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "split_sampling_random_seed", &split_sampling_random_seed_));
+    // Set up the random number generator.
+    if (split_sampling_random_seed_ == 0) {
+      uint64 time_seed = static_cast<uint64>(std::time(NULL));
+      single_rand_ = std::unique_ptr<random::PhiloxRandom>(
+          new random::PhiloxRandom(time_seed));
+    } else {
+      single_rand_ = std::unique_ptr<random::PhiloxRandom>(
+          new random::PhiloxRandom(split_sampling_random_seed_));
+    }
+
+    rng_ = std::unique_ptr<random::SimplePhilox>(
+        new random::SimplePhilox(single_rand_.get()));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_data = context->input(0);
+    const Tensor& node_to_accumulator = context->input(1);
+    const Tensor& leaves = context->input(2);
+    const Tensor& split_features = context->input(3);
+    const Tensor& split_thresholds = context->input(4);
+
+    OP_REQUIRES(context, input_data.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "input_data should be two-dimensional"));
+    OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "node_to_accumulator should be one-dimensional"));
+    OP_REQUIRES(context, leaves.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "leaves should be one-dimensional"));
+    OP_REQUIRES(context, split_features.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "split_features should be two-dimensional"));
+    OP_REQUIRES(context, split_thresholds.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "split_thresholds should be two-dimensional"));
+
+    OP_REQUIRES(
+        context,
+        split_features.shape() == split_thresholds.shape(),
+        errors::InvalidArgument(
+            "split_features and split_thresholds should be the same shape."));
+
+    const auto inputs = input_data.tensor<float, 2>();
+    const auto leaves_vec = leaves.unaligned_flat<int32>();
+    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+    const auto features = split_features.tensor<int32, 2>();
+    const auto thresholds = split_thresholds.tensor<float, 2>();
+
+    const int32 num_data = leaves.shape().dim_size(0);
+    const int32 num_splits = split_features.shape().dim_size(1);
+    const int32 num_features = input_data.shape().dim_size(1);
+
+    std::unordered_map<int32, std::set<int32>> accumulator_to_leaves;
+
+    // The first pass just calculates num_output_accumulators.
+    for (int i = 0; i < num_data; i++) {
+      const int32 leaf = leaves_vec(i);
+      const int32 accumulator = node_map(leaf);
+      // Check for non-fertile node or fertile node that is already
+      // initialized.
+      if (accumulator >= 0 &&
+          !IsAllInitialized(
+              split_features.Slice(accumulator, accumulator + 1))) {
+        accumulator_to_leaves[accumulator].insert(i);
+      }
+    }
+
+    // Now we can allocate the outputs.
+    int num_output_accumulators = accumulator_to_leaves.size();
+    VLOG(1) << "num output accumulators = " << num_output_accumulators;
+    Tensor* accumulators_tensor = nullptr;
+    TensorShape accumulators_shape;
+    accumulators_shape.AddDim(num_output_accumulators);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, accumulators_shape,
+                                            &accumulators_tensor));
+    auto accumulators_flat = accumulators_tensor->tensor<int32, 1>();
+
+    Tensor* new_split_feature_rows_tensor = nullptr;
+    TensorShape new_split_feature_rows_shape;
+    new_split_feature_rows_shape.AddDim(num_output_accumulators);
+    new_split_feature_rows_shape.AddDim(num_splits);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, new_split_feature_rows_shape,
+                                            &new_split_feature_rows_tensor));
+    auto new_split_feature_rows_flat =
+        new_split_feature_rows_tensor->tensor<int32, 2>();
+
+    Tensor* new_split_threshold_rows_tensor = nullptr;
+    TensorShape new_split_threshold_rows_shape;
+    new_split_threshold_rows_shape.AddDim(num_output_accumulators);
+    new_split_threshold_rows_shape.AddDim(num_splits);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, new_split_threshold_rows_shape,
+                                            &new_split_threshold_rows_tensor));
+    auto new_split_threshold_rows_flat =
+        new_split_threshold_rows_tensor->tensor<float, 2>();
+
+    // The second pass fills out the outputs.
+    int output_slot = 0;
+    for (const auto& active : accumulator_to_leaves) {
+      const int32 accumulator = active.first;
+      const std::set<int32> inputs_for_accumulator = active.second;
+      VLOG(1) << "Accumulator " << accumulator
+                  << " gets new output slot " << output_slot;
+      accumulators_flat(output_slot) = accumulator;
+
+      // scatter_update updates entire rows, so we first copy the existing
+      // rows into the output tensors, and then write over the values we
+      // want to change.
+      for (int split = 0; split < num_splits; split++) {
+        new_split_feature_rows_flat(output_slot, split) =
+            features(accumulator, split);
+        new_split_threshold_rows_flat(output_slot, split) =
+            thresholds(accumulator, split);
+      }
+
+      for (const int32 i : inputs_for_accumulator) {
+        VLOG(2) << "Looking at data # " << i;
+
+        int32 num_inits = split_initializations_per_input_;
+        for (int split = 0; split < num_splits && num_inits > 0; split++) {
+          if (new_split_feature_rows_flat(output_slot, split) < 0) {
+            VLOG(1) << "Over-writing @ " << output_slot << "," << split;
+            const int32 index = rng_->Uniform(num_features);
+            new_split_feature_rows_flat(output_slot, split) = index;
+            new_split_threshold_rows_flat(output_slot, split) =
+                inputs(i, index);
+            --num_inits;
+          }
+        }
+      }
+      ++output_slot;
+    }
+  }
+
+ private:
+  int32 split_initializations_per_input_;
+  int32 split_sampling_random_seed_;
+  std::unique_ptr<random::PhiloxRandom> single_rand_;
+  std::unique_ptr<random::SimplePhilox> rng_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("SampleInputs").Device(DEVICE_CPU), SampleInputs);
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/scatter_add_ndim_op.cc b/tensorflow/contrib/tensor_forest/core/ops/scatter_add_ndim_op.cc
new file mode 100644
index 0000000..a65a2c0
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/scatter_add_ndim_op.cc
@@ -0,0 +1,103 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// ScatterAddNdim implements a scatter_add that can operate on sparse
+// updates without being limited to the first dimension for indices.
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+
+
+namespace tensorflow {
+
+REGISTER_OP("ScatterAddNdim")
+  .Input("input: Ref(float)")
+  .Input("indices: int32")
+  .Input("deltas: float")
+
+  .Doc(R"doc(
+  Add elements in deltas to mutable input according to indices.
+
+  input: A N-dimensional float tensor to mutate.
+  indices:= A 2-D int32 tensor. The size of dimension 0 is the number of
+    deltas, the size of dimension 1 is the rank of the input.  `indices[i]`
+    gives the coordinates of input that `deltas[i]` should add to
+  deltas: `deltas[i]` is the value to add to input at index indices[i][:]
+)doc");
+
+
+class ScatterAddNdim : public OpKernel {
+ public:
+  explicit ScatterAddNdim(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    Tensor input_tensor = context->mutable_input(0, false);
+    const Tensor& indices_tensor = context->input(1);
+    const Tensor& deltas_tensor = context->input(2);
+
+    OP_REQUIRES(context, deltas_tensor.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "deltas should be one-dimensional"));
+    if (indices_tensor.shape().dim_size(0) > 0) {
+      OP_REQUIRES(context, indices_tensor.shape().dims() == 2,
+                  errors::InvalidArgument(
+                      "indices should be two-dimensional"));
+      OP_REQUIRES(
+          context,
+          indices_tensor.shape().dim_size(1) == input_tensor.shape().dims(),
+          errors::InvalidArgument(
+              "Number of indices dimensions should be the same as input "
+              "rank."));
+      OP_REQUIRES(
+          context,
+          indices_tensor.shape().dim_size(0) ==
+          deltas_tensor.shape().dim_size(0),
+          errors::InvalidArgument(
+              "Number of updates should be same as number of indices."));
+    } else {
+      return;
+    }
+
+    auto input = input_tensor.flat<float>();
+
+    const auto indices = indices_tensor.tensor<int32, 2>();
+    const auto deltas = deltas_tensor.unaligned_flat<float>();
+
+    const int32 num_dims = indices_tensor.shape().dim_size(1);
+
+    // Calculate index multipliers.
+    std::vector<int32> multipliers;
+    int32 last_size = input.size();
+
+    for (int32 j = 0; j < num_dims; j++) {
+      const int32 m = last_size / input_tensor.shape().dim_size(j);
+      multipliers.push_back(m);
+      last_size = m;
+    }
+
+    // Perform updates.
+    for (int32 i = 0; i < indices_tensor.shape().dim_size(0); i++) {
+      int32 index = 0;
+      for (int32 j = 0; j < num_dims; j++) {
+        index += indices(i, j) * multipliers[j];
+      }
+      input(index) += deltas(i);
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ScatterAddNdim").Device(DEVICE_CPU),
+                        ScatterAddNdim);
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
new file mode 100644
index 0000000..3e84534
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
@@ -0,0 +1,170 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// TreePredictions returns the per-class probabilities for each input by
+// evaluating the given tree.
+#include <algorithm>
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+using tensorforest::CHILDREN_INDEX;
+using tensorforest::FEATURE_INDEX;
+using tensorforest::LEAF_NODE;
+using tensorforest::FREE_NODE;
+
+using tensorforest::DecideNode;
+using tensorforest::Sum;
+
+REGISTER_OP("TreePredictions")
+  .Attr("valid_leaf_threshold: float")
+  .Input("input_data: float")
+  .Input("tree: int32")
+  .Input("tree_thresholds: float")
+  .Input("node_per_class_weights: float")
+
+  .Output("predictions: float")
+  .Doc(R"doc(
+  Returns the per-class probabilities for each input.
+
+  input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
+   gives the j-th feature of the i-th input.
+  tree:= A 2-d int32 tensor.  `tree[0][i]` gives the index of the left child
+   of the i-th node, `tree[0][i] + 1` gives the index of the right child of
+   the i-th node, and `tree[1][i]` gives the index of the feature used to
+   split the i-th node.
+  tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
+   node.
+  node_per_class_weights: `node_per_class_weights[n][c]` records how many
+   training examples have class c and have ended up in node n.
+  predictions: `predictions[i][j]` is the probability that input i is class j.
+  valid_leaf_threshold: Minimum number of samples that have arrived to a leaf
+    to be considered a valid leaf, otherwise use the parent.
+)doc");
+
+
+class TreePredictions : public OpKernel {
+ public:
+  explicit TreePredictions(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(
+      "valid_leaf_threshold", &valid_leaf_threshold_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_data = context->input(0);
+
+    const Tensor& tree_tensor = context->input(1);
+    const Tensor& tree_thresholds = context->input(2);
+    const Tensor& node_per_class_weights = context->input(3);
+
+    OP_REQUIRES(context, tree_tensor.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "tree should be two-dimensional"));
+    OP_REQUIRES(context, tree_thresholds.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "tree_threhsolds should be one-dimensional"));
+    OP_REQUIRES(context, node_per_class_weights.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "node_pcw should be two-dimensional"));
+
+    if (input_data.shape().dim_size(0) > 0) {
+      OP_REQUIRES(context, input_data.shape().dims() == 2,
+                  errors::InvalidArgument(
+                      "input_data should be two-dimensional"));
+    }
+    OP_REQUIRES(
+        context,
+        tree_tensor.shape().dim_size(0) ==
+        tree_thresholds.shape().dim_size(0) &&
+        tree_tensor.shape().dim_size(0) ==
+        node_per_class_weights.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of nodes should be the same in "
+            "tree, tree_thresholds and node_pcw."));
+
+    const int32 num_classes = node_per_class_weights.shape().dim_size(1);
+    const int32 num_data = input_data.shape().dim_size(0);
+
+    Tensor* output_predictions = nullptr;
+    TensorShape output_shape;
+    output_shape.AddDim(num_data);
+    output_shape.AddDim(num_classes);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, output_shape,
+                                            &output_predictions));
+    auto out = output_predictions->tensor<float, 2>();
+
+    const auto node_pcw = node_per_class_weights.tensor<float, 2>();
+    const auto tree = tree_tensor.tensor<int32, 2>();
+    const auto thresholds = tree_thresholds.unaligned_flat<float>();
+
+    for (int i = 0; i < num_data; i++) {
+      const Tensor point = input_data.Slice(i, i+1);
+      int node_index = 0;
+      int parent = -1;
+      while (true) {
+        const int32 left_child = tree(node_index, CHILDREN_INDEX);
+        if (left_child == LEAF_NODE) {
+          float sum = Sum<float>(node_per_class_weights.Slice(
+              node_index, node_index + 1));
+          float parent_weight = 0.0;
+          if (sum < valid_leaf_threshold_ && parent >= 0) {
+            VLOG(1) << "not enough samples at leaf, including parent counts."
+                    << "child sum = " << sum;
+            float parent_sum = Sum<float>(node_per_class_weights.Slice(
+                parent, parent + 1));
+            // Weight the parent's counts just enough so that the new sum is
+            // valid_leaf_threshold_, but never give any counts a weight of
+            // more than 1.
+            parent_weight = std::min(1.0f,
+                                (valid_leaf_threshold_ - sum) / parent_sum);
+            sum += parent_weight * parent_sum;
+            VLOG(1) << "Sum w/ parent included = " << sum;
+          }
+          for (int c = 0; c < num_classes; c++) {
+            float w = node_pcw(node_index, c);
+            if (parent_weight > 0.0) {
+              w += parent_weight * node_pcw(parent, c);
+            }
+            out(i, c) = w / sum;
+          }
+          break;
+        } else if (left_child == FREE_NODE) {
+          LOG(ERROR) << "Reached a free node, not good.";
+          return;
+        }
+        parent = node_index;
+        node_index = left_child +
+            DecideNode(point, tree(node_index, FEATURE_INDEX),
+                       thresholds(node_index));
+      }
+    }
+
+    VLOG(1) << "tree: " << tree;
+    VLOG(1) << "output: " << out;
+  }
+
+ private:
+  float valid_leaf_threshold_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("TreePredictions").Device(DEVICE_CPU),
+                        TreePredictions);
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
new file mode 100644
index 0000000..df50a3d
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
@@ -0,0 +1,67 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+namespace tensorflow {
+namespace tensorforest {
+
+using tensorflow::Tensor;
+
+int32 BestFeature(const Tensor& total_counts, const Tensor& split_counts,
+                  int32 accumulator) {
+  int32 best_feature_index = -1;
+  // We choose the split with the lowest score.
+  float best_score = kint64max;
+  const int32 num_splits = split_counts.shape().dim_size(1);
+  const int32 num_classes = split_counts.shape().dim_size(2);
+  // Ideally, Eigen::Tensor::chip would be best to use here but it results
+  // in seg faults, so we have to go with flat views of these tensors.  However,
+  // it is still pretty efficient because we put off evaluation until the
+  // score is actually returned.
+  const auto tc = total_counts.Slice(
+      accumulator, accumulator + 1).unaligned_flat<float>();
+  const auto splits = split_counts.Slice(
+      accumulator, accumulator + 1).unaligned_flat<float>();
+  Eigen::array<int, 1> bcast({num_splits});
+  const auto rights = tc.broadcast(bcast) - splits;
+
+  for (int i = 0; i < num_splits; i++) {
+    Eigen::array<int, 1> offsets = {i * num_classes};
+    Eigen::array<int, 1> extents = {num_classes};
+    float score = WeightedGiniImpurity(splits.slice(offsets, extents)) +
+        WeightedGiniImpurity(rights.slice(offsets, extents));
+
+    if (score < best_score) {
+      best_score = score;
+      best_feature_index = i;
+    }
+  }
+  return best_feature_index;
+}
+
+bool DecideNode(const Tensor& point, int32 feature, float bias) {
+  const auto p = point.unaligned_flat<float>();
+  return p(feature) > bias;
+}
+
+bool IsAllInitialized(const Tensor& features) {
+  const auto feature_vec = features.unaligned_flat<int32>();
+  return feature_vec(feature_vec.size() - 1) >= 0;
+}
+
+
+}  // namespace tensorforest
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
new file mode 100644
index 0000000..9b7553e
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
@@ -0,0 +1,89 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef LEARNING_LIB_TENSOR_FOREST_V2_TREE_UTILS_H_
+#define LEARNING_LIB_TENSOR_FOREST_V2_TREE_UTILS_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tensorforest {
+
+// Indexes in the tree representation's 2nd dimension for children and features.
+const int32 CHILDREN_INDEX = 0;
+const int32 FEATURE_INDEX = 1;
+
+// Used in the tree's children sub-tensor to indicate leaf and free nodes.
+const int32 LEAF_NODE = -1;
+const int32 FREE_NODE = -2;
+
+// Calculates the sum of a tensor.
+template<typename T>
+T Sum(tensorflow::Tensor counts) {
+  Eigen::Tensor<T, 0, Eigen::RowMajor> count_sum =
+      counts.unaligned_flat<T>().sum();
+  return count_sum(0);
+}
+
+// Given an Eigen::Tensor type, calculate the Gini impurity, which we use
+// to determine the best split (lowest) and which nodes to allocate first
+// (highest).
+template<typename T>
+int32 WeightedGiniImpurity(const T& counts) {
+  // Our split score is the Gini impurity times the number of examples
+  // seen by the leaf.  If c(i) denotes the i-th class count and c = sum_i c(i)
+  // then
+  // score = c * (1 - sum_i ( c(i) / c )^2 )
+  //       = c - sum_i c(i)^2 / c
+  const auto smoothed = counts + counts.constant(1.0f);
+  const auto sum = smoothed.sum();
+  const auto sum2 = smoothed.square().sum();
+  Eigen::Tensor<float, 0, Eigen::RowMajor> ret = sum - (sum2 / sum);
+  return ret(0);
+}
+
+// Returns the best split to use based on the (lowest) Gini impurity.
+// Takes in the whole total and per-split count tensors because using
+// Tensor::Slice returns a tensor of the same dimensionality, which makes
+// things a little awkward.
+// TODO(gilberth): Currently test_util.BestFeatureToSplit doesn't work with
+// this code because the shapes of the incoming tensors are different than
+// in v1.  Try to make it work for both versions?
+int32 BestFeature(const tensorflow::Tensor& total_counts,
+                  const tensorflow::Tensor& split_counts,
+                  int32 accumulator);
+
+// Initializes everything in the given tensor to the given value.
+template <typename T>
+void Initialize(tensorflow::Tensor counts, T val = 0) {
+  auto flat = counts.unaligned_flat<T>();
+  std::fill(flat.data(), flat.data() + flat.size(), val);
+}
+
+// Returns true if the point falls to the right (i.e., the selected feature
+// of the input point is greater than the bias threshold), and false if it
+// falls to the left.
+bool DecideNode(const tensorflow::Tensor& point, int32 feature, float bias);
+
+// Returns true if all the splits are initialized. Since they get initialized
+// in order, we can simply infer this from the last split.
+// This should only be called for a single allocator's candidate features
+// (i.e. candidate_split_features.Slice(accumulator, accumulator + 1) ).
+bool IsAllInitialized(const tensorflow::Tensor& features);
+
+}  // namespace tensorforest
+} // namespace tensorflow
+
+#endif  // LEARNING_LIB_TENSOR_FOREST_V2_TREE_UTILS_H_
diff --git a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
new file mode 100644
index 0000000..0bf7525
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
@@ -0,0 +1,407 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// UpdateFertileSlots manages accumulator slots.  It assigns free or newly
+// finished accumulator slots to waiting non-fertile nodes and new leaves
+// according to their existing split scores (based on node pcws).  It does not
+// allocate slots to leaves that are beyond max depth.
+#include <unordered_map>
+#include <set>
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/top_n.h"
+
+
+namespace tensorflow {
+
+using gtl::TopN;
+using tensorforest::Initialize;
+using tensorforest::WeightedGiniImpurity;
+
+
+REGISTER_OP("UpdateFertileSlots")
+  .Attr("max_depth: int32")
+  .Input("finished: int32")
+  .Input("non_fertile_leaves: int32")
+  .Input("non_fertile_leaf_scores: float")
+  .Input("end_of_tree: int32")
+  .Input("tree_depths: int32")
+  .Input("pcw_candidate_splits: float")
+  .Input("pcw_total_splits: float")
+  .Input("node_to_accumulator: int32")
+  .Output("node_map_updates: int32")
+  .Output("accumulators_cleared: int32")
+  .Output("accumulators_allocated: int32")
+  .Output("new_nonfertile_leaves: int32")
+  .Output("new_nonfertile_leaves_scores: float")
+  .Doc(R"doc(
+  Updates accumulator slots to reflect finished or newly fertile nodes.
+
+  Leaves at the depth of the attribute `max_depth` won't be made fertile
+  (i.e., won't be given an accumulator slot.)
+
+  finished:= A 1-d int32 tensor containing the indices of fertile nodes that
+    are ready to decide on a split.
+  non_fertile_leaves:= A 1-d int32 tensor containing the indices of all the
+    currently non-fertile leaves.  If there are free accumulator slots after
+    deallocation, UpdateFertileSlots will consider these nodes (plus the ones
+    in new_leaves) and potentially turn some of them fertile.
+  non_fertile_leaf_scores: `non_fertile_leaf_scores[i]` is the splitting score
+    of the non-fertile leaf `non_fertile_leaves[i]`.
+  end_of_tree: The end of tree tensor from the previous training iteration, used
+    with the finished input to calculate a list of new leaf indices created by
+    GrowTree, which will be considered to become fertile if there are free
+    slots.
+  tree_depths: `tree_depths[i]` is the depth in the tree of node i.
+  pcw_candidate_splits: `pcw_candidate_splits[a][s][c]` records how many
+    training examples have class c and have ended up in the fertile node
+    associated with accumulator slot a and then taken the *left* branch of
+    candidate split s.
+  pcw_total_splits: `pcw_total_splits[a][c]` records how many training examples
+    have class c and have ended up in the fertile node associated with
+    accumulator slot a.  Between that and `pcw_candidate_splits`, the number of
+    examples taking the right branch of a split can be reconstructed.
+  node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
+    fertile node i, or -1 if node i isn't fertile.
+  node_map_updates:= A 2-d int32 tensor describing the changes that need to
+    be applied to the node_to_accumulator map.  Intended to be used with
+    `tf.scatter_update(node_to_accumulator,
+                       node_map_updates[0],
+                       node_map_updates[1])`.
+  accumulators_cleared:= A 1-d int32 tensor containing the indices of all
+    the accumulator slots that need to be cleared.
+  accumulators_allocated:= A 1-d int32 tensor containing the indices of all
+    the accumulator slots that need to be allocated.
+  new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the
+    leaves that are now non-fertile.
+  new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the
+    splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`.
+)doc");
+
+
+class UpdateFertileSlots : public OpKernel {
+ public:
+  explicit UpdateFertileSlots(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(
+      "max_depth", &max_depth_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& finished = context->input(0);
+
+    const Tensor& non_fertile_leaves =  context->input(1);
+    const Tensor& non_fertile_leaf_scores =  context->input(2);
+    const Tensor& end_of_tree = context->input(3);
+    const Tensor& tree_depths = context->input(4);
+
+    const Tensor& pcw_candidate_splits = context->input(5);
+    const Tensor& pcw_total_splits = context->input(6);
+    const Tensor& node_to_accumulator = context->input(7);
+
+    OP_REQUIRES(context, finished.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "finished should be one-dimensional"));
+    OP_REQUIRES(context, non_fertile_leaves.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "non_fertile_leaves should be one-dimensional"));
+    OP_REQUIRES(context, non_fertile_leaf_scores.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "non_fertile_leaves_scores should be one-dimensional"));
+    OP_REQUIRES(context, end_of_tree.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "end_of_tree should be one-dimensional"));
+    OP_REQUIRES(context, tree_depths.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "tree_depths should be one-dimensional"));
+    OP_REQUIRES(context, pcw_candidate_splits.shape().dims() == 3,
+                errors::InvalidArgument(
+                    "pcw_candidate_splits should be three-dimensional"));
+    OP_REQUIRES(context, pcw_total_splits.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "pcw_total_splits should be two-dimensional"));
+     OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "node_to_accumulator should be one-dimensional"));
+
+    OP_REQUIRES(
+        context,
+        pcw_candidate_splits.shape().dim_size(0) ==
+        pcw_total_splits.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of accumulators should be the same in pcw_candidate_splits "
+            "and pcw_total_splits."));
+    OP_REQUIRES(
+        context,
+        non_fertile_leaves.shape().dim_size(0) ==
+        non_fertile_leaf_scores.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Number of non fertile leaves should be the same in "
+            "non_fertile_leaves and non_fertile_leaf_scores."));
+
+    // Read finished accumulators into a set for quick lookup.
+    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+    const auto finished_vec = finished.unaligned_flat<int32>();
+    std::set<int32> finished_accumulators;
+    for (int32 i = 0; i < finished_vec.size(); ++i) {
+      finished_accumulators.insert(node_map(finished_vec(i)));
+    }
+
+    // Construct leaf heap to sort leaves to allocate accumulators to.
+    const auto eot = end_of_tree.unaligned_flat<int32>();
+    const int32 num_nodes = tree_depths.shape().dim_size(0);
+    const int32 num_finished = finished.shape().dim_size(0);
+    const int32 num_new_leaves = std::min(num_finished * 2, num_nodes - eot(0));
+
+    LeafHeapType leaf_heap(
+        non_fertile_leaves.shape().dim_size(0) +
+        num_new_leaves, OrderBySecondGreater());
+    ConstructLeafHeap(
+        non_fertile_leaves, non_fertile_leaf_scores, tree_depths,
+        eot(0), num_new_leaves, pcw_total_splits.shape().dim_size(1),
+        &leaf_heap);
+
+    // Allocate leaves.
+    std::unique_ptr<HeapValuesType> values(
+        leaf_heap.Extract());
+    int32 accumulator = -1;  // This will first get incremented to 0.
+    int32 num_accumulators_allocated = 0;
+    std::unordered_map<int32, int32> accumulators_to_node;
+    FindNextAccumulator(pcw_total_splits, finished_accumulators, &accumulator);
+    int i = 0;
+    for (; i < values->size(); ++i) {
+      const std::pair<int32, float>& node = (*values)[i];
+      if (accumulator < 0) {
+        VLOG(1) << "No allocators left.";
+        break;
+      }
+      VLOG(1) << "setting node " << node.first << " to accumulator "
+              << accumulator;
+      ++num_accumulators_allocated;
+      accumulators_to_node[accumulator] = node.first;
+
+      FindNextAccumulator(pcw_total_splits, finished_accumulators,
+                          &accumulator);
+    }
+
+    // Construct and fill outputs.
+    SetNodeMapUpdates(accumulators_to_node, finished, context);
+    SetAccumulatorsCleared(finished_accumulators,
+                           accumulators_to_node, context);
+    SetAccumulatorsAllocated(accumulators_to_node, context);
+    SetNewNonFertileLeaves(values.get(), i, context);
+  }
+
+ private:
+  struct OrderBySecondGreater {
+    bool operator()(const std::pair<int32, float> &left,
+                    const std::pair<int32, float> &right) {
+        return left.second > right.second;
+    }
+  };
+
+  typedef TopN<std::pair<int32, float>, OrderBySecondGreater> LeafHeapType;
+  typedef std::vector<std::pair<int32, float>> HeapValuesType;
+
+  // Creates an update tensor for node to accumulator map.  Sets finished nodes
+  // to -1 (no accumulator assigned) and newly allocated nodes to their
+  // accumulator.
+  void SetNodeMapUpdates(
+      const std::unordered_map<int32, int32>& accumulators_to_node,
+      const Tensor& finished, OpKernelContext* context) {
+    // Node map updates.
+    Tensor* output_node_map = nullptr;
+    TensorShape node_map_shape;
+    node_map_shape.AddDim(2);
+    node_map_shape.AddDim(accumulators_to_node.size() +
+                          finished.shape().dim_size(0));
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, node_map_shape,
+                                            &output_node_map));
+
+    auto out_node = output_node_map->tensor<int32, 2>();
+    int32 output_slot = 0;
+
+    // Set finished nodes to -1.
+    const auto finished_vec = finished.unaligned_flat<int32>();
+    for (int32 i = 0; i < finished_vec.size(); ++i) {
+      out_node(0, output_slot) = finished_vec(i);
+      out_node(1, output_slot)  = -1;
+      ++output_slot;
+    }
+
+    // Set newly allocated nodes to their allocator.
+    for (const auto& node_alloc_pair : accumulators_to_node) {
+      out_node(0, output_slot) = node_alloc_pair.second;
+      out_node(1, output_slot) = node_alloc_pair.first;
+      ++output_slot;
+    }
+  }
+
+  // Creates output tensor for cleared accumulators. Cleared accumulators are
+  // those that were finished but not re-allocated.
+  void SetAccumulatorsCleared(
+      const std::set<int32>& finished_accumulators,
+      const std::unordered_map<int32, int32>& accumulators_to_node,
+      OpKernelContext* context) {
+    std::set<int32> cleared;
+    for (const int32 node : finished_accumulators) {
+      if (accumulators_to_node.find(node) == accumulators_to_node.end()) {
+        cleared.insert(node);
+      }
+    }
+
+    Tensor* output_cleared = nullptr;
+    TensorShape cleared_shape;
+    cleared_shape.AddDim(cleared.size());
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, cleared_shape,
+                                            &output_cleared));
+
+    auto out = output_cleared->unaligned_flat<int32>();
+
+    int32 i = 0;
+    for (const int32 accumulator : cleared) {
+      out(i) = accumulator;
+      ++i;
+    }
+  }
+
+  // Creates output tensor for accumulators that were allocated to now-fertile
+  // nodes.
+  void SetAccumulatorsAllocated(
+      const std::unordered_map<int32, int32>& accumulators_to_node,
+      OpKernelContext* context) {
+    // Node map updates.
+    Tensor* output_allocated = nullptr;
+    TensorShape allocated_shape;
+    allocated_shape.AddDim(accumulators_to_node.size());
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, allocated_shape,
+                                            &output_allocated));
+
+    auto out = output_allocated->unaligned_flat<int32>();
+    int32 output_slot = 0;
+
+    // Set newly allocated nodes to their allocator.
+    for (const auto& node_alloc_pair : accumulators_to_node) {
+      out(output_slot) = node_alloc_pair.first;
+      ++output_slot;
+    }
+  }
+
+  // Creates output tensors for non-fertile leaves and non-fertile leaf scores.
+  // Start indicates the index in values where the leaves that weren't
+  // allocated this round begin, and should thus be placed in the new
+  // nonfertile_leaves tensors.
+  void SetNewNonFertileLeaves(HeapValuesType* values, int start,
+                              OpKernelContext* context) {
+    // Node map updates.
+    int32 num_values = values->size() - start;
+
+    // Unfortunately, a zero-sized Variable results in an uninitialized
+    // error, probably because they check for zero size instead of
+    // a real inititalization condition.
+    bool fill_with_garbage = false;
+    if (num_values == 0) {
+      num_values = 1;
+      fill_with_garbage = true;
+    }
+    Tensor* output_nonfertile_leaves = nullptr;
+    TensorShape nonfertile_leaves_shape;
+    nonfertile_leaves_shape.AddDim(num_values);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(3, nonfertile_leaves_shape,
+                                            &output_nonfertile_leaves));
+
+    auto out_nonfertile_leaves =
+        output_nonfertile_leaves->unaligned_flat<int32>();
+
+    Tensor* output_nonfertile_leaves_scores = nullptr;
+    TensorShape nonfertile_leaves_scores_shape;
+    nonfertile_leaves_scores_shape.AddDim(num_values);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(4, nonfertile_leaves_scores_shape,
+                                            &output_nonfertile_leaves_scores));
+
+    auto out_nonfertile_leaves_scores =
+        output_nonfertile_leaves_scores->unaligned_flat<float>();
+
+    if (fill_with_garbage) {
+      out_nonfertile_leaves(0) = -1;
+      out_nonfertile_leaves_scores(0) = 0.0;
+      return;
+    }
+
+    for (int32 i = start; i < values->size(); ++i) {
+      const std::pair<int32, float>& node = (*values)[i];
+      out_nonfertile_leaves(i -start) = node.first;
+      out_nonfertile_leaves_scores(i - start) = node.second;
+    }
+  }
+
+  void ConstructLeafHeap(
+      const Tensor& non_fertile_leaves, const Tensor& non_fertile_leaf_scores,
+      const Tensor& tree_depths, int32 end_of_tree, int32 num_new_leaves,
+      int32 num_classes, LeafHeapType* leaf_heap) {
+    const auto leaf_vec = non_fertile_leaves.unaligned_flat<int32>();
+    const auto leaf_score_vec = non_fertile_leaf_scores.unaligned_flat<float>();
+    const auto depths = tree_depths.unaligned_flat<int32>();
+
+    for (int i = 0; i < leaf_vec.size(); i++) {
+      const int32 leaf = leaf_vec(i);
+      // Filter out leaves < 0, non_fertile_nodes can contain garbage at
+      // startup.
+      if (leaf >= 0 && depths(leaf) < max_depth_) {
+        leaf_heap->push(std::make_pair(leaf, leaf_score_vec(i)));
+      }
+    }
+
+    // Add new leaves.
+    Eigen::Tensor<float, 1, 1> zeros(num_classes);
+    zeros.setZero();
+    const float zero_score = WeightedGiniImpurity(zeros);
+    for (int leaf = end_of_tree; leaf < end_of_tree + num_new_leaves; leaf++) {
+      if (depths(leaf) < max_depth_) {
+        leaf_heap->push(std::make_pair(leaf, zero_score));
+      }
+    }
+  }
+
+  // Finds the next available or newly-finished accumulator.
+  void FindNextAccumulator(Tensor totals_tensor,
+                           const std::set<int32>& finished_accumulators,
+                           int* current) {
+    ++(*current);
+    const auto totals = totals_tensor.tensor<float, 2>();
+    for (; *current < totals_tensor.shape().dim_size(0); ++(*current)) {
+      if (totals(*current, 0) < 0 ||
+          finished_accumulators.find(*current) != finished_accumulators.end()) {
+        return;
+      }
+    }
+    *current = -1;
+  }
+
+  int32 max_depth_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("UpdateFertileSlots").Device(DEVICE_CPU),
+                        UpdateFertileSlots);
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
new file mode 100644
index 0000000..ead6198
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
@@ -0,0 +1,71 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.best_splits_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow  # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class BestSplitsTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.finished = [3, 5]
+    self.node_map = [-1, -1, -1, 0, -1, 3, -1, -1, -1]
+    self.candidate_counts = [[[50., 60., 40., 3.], [70., 30., 70., 30.]],
+                             [[0., 0., 0., 0.], [0., 0., 0., 0.]],
+                             [[0., 0., 0., 0.], [0., 0., 0., 0.]],
+                             [[10., 10., 10., 10.], [10., 5., 5., 10.]]]
+    self.total_counts = [[100., 100., 100., 100.],
+                         [0., 0., 0., 0.],
+                         [0., 0., 0., 0.],
+                         [100., 100., 100., 100.]]
+    self.ops = training_ops.Load()
+
+  def testSimple(self):
+    with self.test_session():
+      split_indices = self.ops.best_splits(
+          self.finished, self.node_map, self.candidate_counts,
+          self.total_counts)
+
+      self.assertAllEqual([0, 1], split_indices.eval())
+
+  def testNoFinished(self):
+    with self.test_session():
+      split_indices = self.ops.best_splits(
+          [], self.node_map, self.candidate_counts, self.total_counts)
+
+      self.assertAllEqual([], split_indices.eval())
+
+  def testBadInput(self):
+    del self.total_counts[1]
+
+    with self.test_session():
+      with self.assertRaisesOpError(
+          'Number of accumulators should be the same in pcw_candidate_splits '
+          'and pcw_total_splits.'):
+        self.ops.best_splits(
+            self.finished, self.node_map, self.candidate_counts,
+            self.total_counts).eval()
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
new file mode 100644
index 0000000..e93bb17
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
@@ -0,0 +1,94 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.count_extremely_random_stats."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow  # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class CountExtremelyRandomStatsTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.input_data = [[-1., 0.], [-1., 2.],  # node 1
+                       [1., 0.], [1., -2.]]  # node 2
+    self.input_labels = [0, 1, 2, 3]
+    self.tree = [[1, 0], [-1, 0], [-1, 0]]
+    self.tree_thresholds = [0., 0., 0.]
+    self.node_map = [-1, 0, -1]
+    self.split_features = [[1], [-1]]
+    self.split_thresholds = [[1.], [0.]]
+    self.ops = training_ops.Load()
+
+  def testSimple(self):
+    with self.test_session():
+      (pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
+       pcw_totals_delta, leaves) = (
+           self.ops.count_extremely_random_stats(
+               self.input_data, self.input_labels, self.tree,
+               self.tree_thresholds, self.node_map,
+               self.split_features, self.split_thresholds, num_classes=4))
+
+      self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.],
+                           [0., 0., 1., 1.]],
+                          pcw_node.eval())
+      self.assertAllEqual([[0, 0, 0]], pcw_splits_indices.eval())
+      self.assertAllEqual([1.], pcw_splits_delta.eval())
+      self.assertAllEqual([[0, 1], [0, 0]], pcw_totals_indices.eval())
+      self.assertAllEqual([1., 1.], pcw_totals_delta.eval())
+      self.assertAllEqual([1, 1, 2, 2], leaves.eval())
+
+  def testNoAccumulators(self):
+    with self.test_session():
+      (pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
+       pcw_totals_delta, leaves) = (
+           self.ops.count_extremely_random_stats(
+               self.input_data, self.input_labels, self.tree,
+               self.tree_thresholds, [-1] * 3,
+               self.split_features, self.split_thresholds, num_classes=4))
+
+      self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.],
+                           [0., 0., 1., 1.]],
+                          pcw_node.eval())
+      self.assertEquals((0, 3), pcw_splits_indices.eval().shape)
+      self.assertAllEqual([], pcw_splits_delta.eval())
+      self.assertEquals((0, 2), pcw_totals_indices.eval().shape)
+      self.assertAllEqual([], pcw_totals_delta.eval())
+      self.assertAllEqual([1, 1, 2, 2], leaves.eval())
+
+  def testBadInput(self):
+    del self.node_map[-1]
+
+    with self.test_session():
+      with self.assertRaisesOpError(
+          'Number of nodes should be the same in '
+          'tree, tree_thresholds, and node_to_accumulator'):
+        pcw_node, _, _, _, _, _ = (
+            self.ops.count_extremely_random_stats(
+                self.input_data, self.input_labels, self.tree,
+                self.tree_thresholds, self.node_map,
+                self.split_features, self.split_thresholds, num_classes=4))
+
+        self.assertAllEqual([], pcw_node.eval())
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
new file mode 100644
index 0000000..84583ae
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
@@ -0,0 +1,63 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.finished_nodes_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow  # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class FinishedNodesTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.leaves = [1, 3, 4]
+    self.node_map = [-1, -1, -1, 0, 1, -1]
+    self.pcw_total_splits = [[3, 3], [4, 7], [0, 0], [0, 0], [0, 0]]
+    self.ops = training_ops.Load()
+
+  def testSimple(self):
+    with self.test_session():
+      finished = self.ops.finished_nodes(self.leaves, self.node_map,
+                                         self.pcw_total_splits,
+                                         num_split_after_samples=10)
+
+      self.assertAllEqual([4], finished.eval())
+
+  def testNoAccumulators(self):
+    with self.test_session():
+      finished = self.ops.finished_nodes(self.leaves, [-1] * 6,
+                                         self.pcw_total_splits,
+                                         num_split_after_samples=10)
+
+      self.assertAllEqual([], finished.eval())
+
+  def testBadInput(self):
+    with self.test_session():
+      with self.assertRaisesOpError(
+          'leaf_tensor should be one-dimensional'):
+        finished = self.ops.finished_nodes([self.leaves], self.node_map,
+                                           self.pcw_total_splits,
+                                           num_split_after_samples=10)
+
+        self.assertAllEqual([], finished.eval())
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/grow_tree_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/grow_tree_op_test.py
new file mode 100644
index 0000000..8177e2d
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/grow_tree_op_test.py
@@ -0,0 +1,105 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.grow_tree_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class GrowTreeTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.tree = tf.Variable([[1, 0], [-1, 0], [-1, 0],
+                             [-2, 0], [-2, 0], [-2, 0], [-2, 0]])
+    self.tree_thresholds = tf.Variable([0., 0., 0., 0., 0., 0., 0.])
+    self.eot = tf.Variable([3])
+    self.depths = tf.Variable([1, 2, 2, -1, -1, -1, -1])
+    self.node_map = [-1, 0, 1, -1, -1, -1, -1]
+    self.finished = [1, 2]
+    self.best_splits = [2, 3]
+    self.split_features = [[1, 2, 3, 4], [5, 6, 7, 8]]
+    self.split_thresholds = [[10., 20., 30., 40.], [50., 60., 70., 80.]]
+    self.ops = training_ops.Load()
+
+  def testSimple(self):
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      update_list, tree_updates, threshold_updates, depth_updates, new_eot = (
+          self.ops.grow_tree(self.eot, self.depths, self.node_map,
+                             self.finished, self.best_splits,
+                             self.split_features, self.split_thresholds))
+
+      self.assertAllEqual([1, 3, 4, 2, 5, 6], update_list.eval())
+      self.assertAllEqual(
+          [[3, 3], [-1, -1], [-1, -1], [5, 8], [-1, -1], [-1, -1]],
+          tree_updates.eval())
+      self.assertAllEqual([30.0, 0.0, 0.0, 80.0, 0.0, 0.0],
+                          threshold_updates.eval())
+      self.assertAllEqual([2, 3, 3, 2, 3, 3], depth_updates.eval())
+      self.assertAllEqual([7], new_eot.eval())
+
+  def testNoRoomToGrow(self):
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      # Even though there's one free node, there needs to be 2 to grow.
+      tf.assign(self.eot, [6]).eval()
+
+      update_list, tree_updates, threshold_updates, depth_updates, new_eot = (
+          self.ops.grow_tree(self.eot, self.depths, self.node_map,
+                             self.finished, self.best_splits,
+                             self.split_features, self.split_thresholds))
+
+      self.assertAllEqual([], update_list.eval())
+      self.assertEquals((0, 2), tree_updates.eval().shape)
+      self.assertAllEqual([], threshold_updates.eval())
+      self.assertAllEqual([], depth_updates.eval())
+      self.assertAllEqual([6], new_eot.eval())
+
+  def testNoFinished(self):
+    with self.test_session():
+      tf.initialize_all_variables().run()
+
+      update_list, tree_updates, threshold_updates, depth_updates, new_eot = (
+          self.ops.grow_tree(self.eot, self.depths, self.node_map, [], [],
+                             self.split_features, self.split_thresholds))
+
+      self.assertAllEqual([], update_list.eval())
+      self.assertAllEqual((0, 2), tree_updates.eval().shape)
+      self.assertAllEqual([], threshold_updates.eval())
+      self.assertAllEqual([], depth_updates.eval())
+      self.assertAllEqual([3], new_eot.eval())
+
+  def testBadInput(self):
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      with self.assertRaisesOpError(
+          'Number of finished nodes should be the same in finished and '
+          'best_splits.'):
+        update_list, _, _, _, _ = (
+            self.ops.grow_tree(self.eot, self.depths, self.node_map,
+                               [], self.best_splits,
+                               self.split_features, self.split_thresholds))
+        self.assertAllEqual([], update_list.eval())
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
new file mode 100644
index 0000000..d050747
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
@@ -0,0 +1,79 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.sample_inputs_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class SampleInputsTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.input_data = [[-1., 10.], [-10., 2.],  # node 1
+                       [20., 50.], [1., -2.]]  # node 2
+    self.node_map = [-1, 0, 1]
+    self.leaves = [1, 1, 2, 2]
+    self.split_features = [[-1, -1, -1], [1, 0, -1], [-1, -1, -1]]
+    self.split_thresholds = [[0., 0., 0.], [5., -2., 0.], [0., 0., 0.]]
+    self.ops = training_ops.Load()
+
+  def testSimple(self):
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      indices, feature_updates, threshold_updates = (
+          self.ops.sample_inputs(
+              self.input_data, self.node_map, self.leaves, self.split_features,
+              self.split_thresholds, split_initializations_per_input=1,
+              split_sampling_random_seed=3))
+      self.assertAllEqual([1, 0], indices.eval())
+      self.assertAllEqual([[1, 0, 1], [0, 0, -1]],
+                          feature_updates.eval())
+      self.assertAllEqual([[5., -2., 50.], [-1., -10., 0.]],
+                          threshold_updates.eval())
+
+  def testNoAccumulators(self):
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      indices, feature_updates, threshold_updates = (
+          self.ops.sample_inputs(
+              self.input_data, [-1] * 3, self.leaves, self.split_features,
+              self.split_thresholds, split_initializations_per_input=1,
+              split_sampling_random_seed=3))
+      self.assertAllEqual([], indices.eval())
+      self.assertAllEqual((0, 3), feature_updates.eval().shape)
+      self.assertAllEqual((0, 3), threshold_updates.eval().shape)
+
+  def testBadInput(self):
+    del self.split_features[1]
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      with self.assertRaisesOpError(
+          'split_features and split_thresholds should be the same shape.'):
+        indices, _, _ = self.ops.sample_inputs(
+            self.input_data, self.node_map, self.leaves, self.split_features,
+            self.split_thresholds, split_initializations_per_input=1,
+            split_sampling_random_seed=3)
+        self.assertAllEqual([], indices.eval())
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
new file mode 100644
index 0000000..467ffed
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
@@ -0,0 +1,82 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.scatter_add_ndim_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class ScatterAddNdimTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.ops = training_ops.Load()
+
+  def test1dim(self):
+    input_data = tf.Variable([1., 2., 3., 4., 5., 6.,
+                              7., 8., 9., 10., 11., 12.])
+    indices = [[1], [10]]
+    updates = [100., 200.]
+
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      self.ops.scatter_add_ndim(input_data, indices, updates).run()
+      self.assertAllEqual([1., 102., 3., 4., 5., 6.,
+                           7., 8., 9., 10., 211., 12.], input_data.eval())
+
+  def test3dim(self):
+    input_data = tf.Variable([[[1., 2., 3.], [4., 5., 6.]],
+                              [[7., 8., 9.], [10., 11., 12.]]])
+    indices = [[0, 0, 1], [1, 1, 2]]
+    updates = [100., 200.]
+
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      self.ops.scatter_add_ndim(input_data, indices, updates).run()
+      self.assertAllEqual([[[1., 102., 3.], [4., 5., 6.]],
+                           [[7., 8., 9.], [10., 11., 212.]]], input_data.eval())
+
+  def testNoUpdates(self):
+    init_val = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]]
+    input_data = tf.Variable(init_val)
+    indices = []
+    updates = []
+
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      self.ops.scatter_add_ndim(input_data, indices, updates).run()
+      self.assertAllEqual(init_val, input_data.eval())
+
+  def testBadInput(self):
+    init_val = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]]
+    input_data = tf.Variable(init_val)
+    indices = [[0, 0, 1], [1, 1, 2]]
+    updates = [100.]
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      with self.assertRaisesOpError(
+          'Number of updates should be same as number of indices.'):
+        self.ops.scatter_add_ndim(input_data, indices, updates).run()
+        self.assertAllEqual(init_val, input_data.eval())
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
new file mode 100644
index 0000000..743cd83
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
@@ -0,0 +1,103 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.tree_predictions_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow  # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import inference_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class TreePredictionsTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.ops = inference_ops.Load()
+
+  def testSimple(self):
+    input_data = [[-1., 0.], [-1., 2.],  # node 1
+                  [1., 0.], [1., -2.]]  # node 2
+
+    tree = [[1, 0], [-1, 0], [-1, 0]]
+    tree_thresholds = [0., 0., 0.]
+    node_pcw = [[0.3, 0.4, 0.3], [0.1, 0.1, 0.8], [0.5, 0.25, 0.25]]
+
+    with self.test_session():
+      predictions = self.ops.tree_predictions(
+          input_data, tree, tree_thresholds, node_pcw,
+          valid_leaf_threshold=1)
+
+      self.assertAllClose([[0.1, 0.1, 0.8], [0.1, 0.1, 0.8],
+                           [0.5, 0.25, 0.25], [0.5, 0.25, 0.25]],
+                          predictions.eval())
+
+  def testBackoffToParent(self):
+    input_data = [[-1., 0.], [-1., 2.],  # node 1
+                  [1., 0.], [1., -2.]]  # node 2
+
+    tree = [[1, 0], [-1, 0], [-1, 0]]
+    tree_thresholds = [0., 0., 0.]
+    node_pcw = [[3.0, 9.0, 3.0], [1.0, 1.0, 3.0], [5.0, 20.0, 0.0]]
+
+    with self.test_session():
+      predictions = self.ops.tree_predictions(
+          input_data, tree, tree_thresholds, node_pcw,
+          valid_leaf_threshold=10)
+
+      # Node 2 has enough data, but Node 1 needs to combine with the parent
+      # counts.
+      self.assertAllClose([[0.2, 0.4, 0.4], [0.2, 0.4, 0.4],
+                           [0.2, 0.8, 0.0], [0.2, 0.8, 0.0]],
+                          predictions.eval())
+
+  def testNoInput(self):
+    input_data = []
+
+    tree = [[1, 0], [-1, 0], [-1, 0]]
+    tree_thresholds = [0., 0., 0.]
+    node_pcw = [[0.3, 0.4, 0.3], [0.1, 0.1, 0.8], [0.5, 0.25, 0.25]]
+
+    with self.test_session():
+      predictions = self.ops.tree_predictions(
+          input_data, tree, tree_thresholds, node_pcw,
+          valid_leaf_threshold=10)
+
+      self.assertEquals((0, 3), predictions.eval().shape)
+
+  def testBadInput(self):
+    input_data = [[-1., 0.], [-1., 2.],  # node 1
+                  [1., 0.], [1., -2.]]  # node 2
+
+    tree = [[1, 0], [-1, 0], [-1, 0]]
+    tree_thresholds = [0., 0.]  # not enough nodes.
+    node_pcw = [[0.3, 0.4, 0.3], [0.1, 0.1, 0.8], [0.5, 0.25, 0.25]]
+
+    with self.test_session():
+      with self.assertRaisesOpError(
+          'Number of nodes should be the same in tree, tree_thresholds '
+          'and node_pcw.'):
+        predictions = self.ops.tree_predictions(
+            input_data, tree, tree_thresholds, node_pcw,
+            valid_leaf_threshold=10)
+
+        self.assertEquals((0, 3), predictions.eval().shape)
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
new file mode 100644
index 0000000..36c232f
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
@@ -0,0 +1,102 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.allocate_deallocate_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow  # pylint: disable=unused-import
+
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    # tree is:
+    #         0
+    #     1       2
+    #   3   4   5   6
+    self.finished = [2]
+    self.depths = [1, 2, 2, 3, 3, 3, 3]
+    self.non_fertile_leaves = [3, 4]
+    self.non_fertile_leaf_scores = [10., 15.]
+    self.end_of_tree = [5]
+    self.node_map = [-1, -1, 0, -1, -1, -1, -1]
+    self.candidate_counts = [[[10., 20.], [30., 10.]]]
+    self.total_counts = [[40., 40.]]
+    self.ops = training_ops.Load()
+
+  def testSimple(self):
+    with self.test_session():
+      (node_map_updates, accumulators_cleared, accumulators_allocated,
+       new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+           self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
+           self.end_of_tree, self.depths, self.candidate_counts,
+           self.total_counts, self.node_map, max_depth=4)
+
+      self.assertAllEqual([[2, 4], [-1, 0]], node_map_updates.eval())
+      self.assertAllEqual([], accumulators_cleared.eval())
+      self.assertAllEqual([0], accumulators_allocated.eval())
+      self.assertAllEqual([3, 5, 6], new_nfl.eval())
+      self.assertAllEqual([10., 1., 1.], new_nfl_scores.eval())
+
+  def testReachedMaxDepth(self):
+    with self.test_session():
+      (node_map_updates, accumulators_cleared, accumulators_allocated,
+       new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+           self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
+           self.end_of_tree, self.depths, self.candidate_counts,
+           self.total_counts, self.node_map, max_depth=3)
+
+      self.assertAllEqual([[2], [-1]], node_map_updates.eval())
+      self.assertAllEqual([0], accumulators_cleared.eval())
+      self.assertAllEqual([], accumulators_allocated.eval())
+      self.assertAllEqual([-1], new_nfl.eval())
+      self.assertAllEqual([0.0], new_nfl_scores.eval())
+
+  def testNoFinished(self):
+    with self.test_session():
+      (node_map_updates, accumulators_cleared, accumulators_allocated,
+       new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+           [], self.non_fertile_leaves, self.non_fertile_leaf_scores,
+           self.end_of_tree, self.depths, self.candidate_counts,
+           self.total_counts, self.node_map, max_depth=4)
+
+      self.assertAllEqual((2, 0), node_map_updates.eval().shape)
+      self.assertAllEqual([], accumulators_cleared.eval())
+      self.assertAllEqual([], accumulators_allocated.eval())
+      self.assertAllEqual([4, 3], new_nfl.eval())
+      self.assertAllEqual([15., 10.], new_nfl_scores.eval())
+
+  def testBadInput(self):
+    del self.non_fertile_leaf_scores[-1]
+    with self.test_session():
+      with self.assertRaisesOpError(
+          'Number of non fertile leaves should be the same in '
+          'non_fertile_leaves and non_fertile_leaf_scores.'):
+        (node_map_updates, _, _, _, _) = self.ops.update_fertile_slots(
+            self.finished, self.non_fertile_leaves,
+            self.non_fertile_leaf_scores, self.end_of_tree, self.depths,
+            self.candidate_counts, self.total_counts,
+            self.node_map, max_depth=4)
+        self.assertAllEqual((2, 0), node_map_updates.eval().shape)
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
new file mode 100644
index 0000000..7cad6a8
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
@@ -0,0 +1,63 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for BrainTree v2 tree evaluation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+
+import tensorflow as tf
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+INFERENCE_OPS_FILE = '_inference_ops.so'
+
+_inference_ops = None
+_ops_lock = threading.Lock()
+
+
+ops.NoGradient('TreePredictions')
+
+
+@ops.RegisterShape('TreePredictions')
+def TreePredictions(op):
+  """Shape function for TreePredictions Op."""
+  num_points = op.inputs[0].get_shape()[0].value
+  num_classes = op.inputs[3].get_shape()[1].value
+  # The output of TreePredictions is
+  # [node_pcw(evaluate_tree(x), c) for c in classes for x in input_data].
+  return [tensor_shape.TensorShape([num_points, num_classes])]
+
+
+# Workaround for the fact that importing tensorflow imports contrib
+# (even if a user isn't using this or any other contrib op), but
+# there's not yet any guarantee that the shared object exists.
+# In which case, "import tensorflow" will always crash, even for users that
+# never use contrib.
+def Load():
+  """Load the inference ops library and return the loaded module."""
+  with _ops_lock:
+    global _inference_ops
+    if not _inference_ops:
+      data_files_path = tf.resource_loader.get_data_files_path()
+      tf.logging.info('data path: %s', data_files_path)
+      _inference_ops = tf.load_op_library(os.path.join(
+          data_files_path, INFERENCE_OPS_FILE))
+
+      assert _inference_ops, 'Could not load inference_ops.so'
+  return _inference_ops
diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
new file mode 100644
index 0000000..8ca2491
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
@@ -0,0 +1,110 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for BrainTree v2 training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+
+import tensorflow as tf
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+TRAINING_OPS_FILE = '_training_ops.so'
+
+_training_ops = None
+_ops_lock = threading.Lock()
+
+ops.NoGradient('CountExtremelyRandomStats')
+ops.NoGradient('SampleInputs')
+ops.NoGradient('BestSplits')
+ops.NoGradient('GrowTree')
+ops.NoGradient('FinishedNodes')
+ops.NoGradient('ScatterAddNdim')
+ops.NoGradient('UpdateFertileSlots')
+
+
+@ops.RegisterShape('CountExtremelyRandomStats')
+def _CountExtremelyRandomStatsShape(op):
+  """Shape function for CountExtremelyRandomStats Op."""
+  num_points = op.inputs[0].get_shape()[0].value
+  num_nodes = op.inputs[2].get_shape()[0].value
+  num_classes = op.get_attr('num_classes')
+  # The output of TraverseTree is [leaf_node_index(x) for x in input_data].
+  return [tensor_shape.TensorShape([num_nodes, num_classes]),  # node pcw
+          tensor_shape.TensorShape([None, 3]),
+          tensor_shape.TensorShape([None]),
+          tensor_shape.TensorShape([None, 2]),
+          tensor_shape.TensorShape([None]),
+          tensor_shape.TensorShape([num_points])]
+
+
+@ops.RegisterShape('SampleInputs')
+def _SampleInputsShape(op):
+  """Shape function for SampleInputs Op."""
+  num_splits = op.inputs[3].get_shape()[1].value
+  return [[None], [None, num_splits], [None, num_splits]]
+
+
+@ops.RegisterShape('BestSplits')
+def _BestSplitsShape(op):
+  num_finished = op.inputs[0].get_shape()[0].value
+  return [tensor_shape.TensorShape([num_finished])]
+
+
+@ops.RegisterShape('GrowTree')
+def _GrowTreeShape(unused_op):
+  """Shape function for GrowTree Op."""
+  return [[None], [None, 2], [None], [None], [1]]
+
+
+@ops.RegisterShape('FinishedNodes')
+def _FinishedNodesShape(unused_op):
+  """Shape function for FinishedNodes Op."""
+  return [[None]]
+
+
+@ops.RegisterShape('ScatterAddNdim')
+def _ScatterAddNdimShape(unused_op):
+  """Shape function for ScatterAddNdim Op."""
+  return []
+
+
+@ops.RegisterShape('UpdateFertileSlots')
+def _UpdateFertileSlotsShape(unused_op):
+  """Shape function for UpdateFertileSlots Op."""
+  return [[None, 2], [None], [None], [None], [None]]
+
+
+# Workaround for the fact that importing tensorflow imports contrib
+# (even if a user isn't using this or any other contrib op), but
+# there's not yet any guarantee that the shared object exists.
+# In which case, "import tensorflow" will always crash, even for users that
+# never use contrib.
+def Load():
+  """Load training ops library and return the loaded module."""
+  with _ops_lock:
+    global _training_ops
+    if not _training_ops:
+      data_files_path = tf.resource_loader.get_data_files_path()
+      tf.logging.info('data path: %s', data_files_path)
+      _training_ops = tf.load_op_library(os.path.join(
+          data_files_path, TRAINING_OPS_FILE))
+
+      assert _training_ops, 'Could not load _training_ops.so'
+  return _training_ops
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
new file mode 100644
index 0000000..27b2b6d
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -0,0 +1,537 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extremely random forest graph builder. go/brain-tree."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python.ops import inference_ops
+from tensorflow.contrib.tensor_forest.python.ops import training_ops
+
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+
+# Default parameter values.  These are all only used if the corresponding
+# parameter is not specified when constructing the ForestHParams.
+flags.DEFINE_integer('num_trees', 100, 'Number of trees in forest')
+flags.DEFINE_integer('max_nodes', 10000, 'Maxmimum number of tree nodes.')
+flags.DEFINE_float(
+    'samples_to_decide', 25.0,
+    'Only decide on a split, or only fully use a leaf, after this many '
+    'training samples have been seen.')
+
+# If tree[i][0] equals this value, then i is a leaf node.
+LEAF_NODE = -1
+
+
+# A convenience class for holding random forest hyperparameters.
+#
+# To just get some good default parameters, use:
+#   hparams = ForestHParams(num_classes=2, num_features=40).fill()
+#
+# Note that num_classes can not be inferred and so must always be specified.
+# Also, either num_splits_to_consider or num_features should be set.
+#
+# To override specific values, pass them to the constructor:
+#   hparams = ForestHParams(num_classes=5, num_trees=10, num_features=5).fill()
+#
+# TODO(thomaswc): Inherit from tf.HParams when that is publicly available.
+class ForestHParams(object):
+  """A base class for holding hyperparameters and calculating good defaults."""
+
+  def __init__(self, **kwargs):
+    for name, value in kwargs.iteritems():
+      setattr(self, name, value)
+
+  def values(self):
+    return self.__dict__
+
+  def fill(self):
+    """Intelligently sets any non-specific parameters."""
+    # Fail fast if num_classes isn't set.
+    _ = getattr(self, 'num_classes')
+
+    self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees)
+    self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes)
+
+    # Allow each tree to be unbalanced by up to a factor of 2.
+    self.max_depth = getattr(self, 'max_depth',
+                             int(2 * math.ceil(math.log(self.max_nodes, 2))))
+
+    # The Random Forest literature recommends sqrt(# features) for
+    # classification problems, and p/3 for regression problems.
+    # TODO(thomaswc): Consider capping this for large number of features.
+    if not getattr(self, 'num_splits_to_consider', None):
+      self.num_splits_to_consider = max(10, int(
+          math.ceil(math.sqrt(self.num_features))))
+
+    # max_fertile_nodes doesn't effect performance, only training speed.
+    # We therefore set it primarily based upon space considerations.
+    # Each fertile node takes up num_splits_to_consider times as much
+    # as space as a non-fertile node.  We want the fertile nodes to in
+    # total only take up as much space as the non-fertile nodes, so
+    num_fertile = int(math.ceil(self.max_nodes / self.num_splits_to_consider))
+    # But always use at least 1000 accumulate slots.
+    num_fertile = max(num_fertile, 1000)
+    self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile)
+    # But it also never needs to be larger than the number of leaves,
+    # which is max_nodes / 2.
+    self.max_fertile_nodes = min(self.max_nodes,
+                                 int(math.ceil(self.max_fertile_nodes / 2.0)))
+
+    # split_after_samples and valid_leaf_threshold should be about the same.
+    # Therefore, if either is set, use it to set the other.  Otherwise, fall
+    # back on FLAGS.samples_to_decide.
+    samples_to_decide = (
+        getattr(self, 'split_after_samples',
+                getattr(self, 'valid_leaf_threshold', FLAGS.samples_to_decide)))
+    self.split_after_samples = getattr(self, 'split_after_samples',
+                                       samples_to_decide)
+    self.valid_leaf_threshold = getattr(self, 'valid_leaf_threshold',
+                                        samples_to_decide)
+
+    # We have num_splits_to_consider slots to fill, and we want to spend
+    # approximately split_after_samples samples initializing them.
+    num_split_initializiations_per_input = max(1, int(math.floor(
+        self.num_splits_to_consider / self.split_after_samples)))
+    self.split_initializations_per_input = getattr(
+        self, 'split_initializations_per_input',
+        num_split_initializiations_per_input)
+
+    # If base_random_seed is 0, the current time will be used to seed the
+    # random number generators for each tree.  If non-zero, the i-th tree
+    # will be seeded with base_random_seed + i.
+    self.base_random_seed = getattr(self, 'base_random_seed', 0)
+
+    return self
+
+
+# A simple container to hold the training variables for a single tree.
+class TreeTrainingVariables(object):
+
+  def __init__(self, params):
+    self.tree = tf.Variable(
+        [[-1, -1]] + [[-2, -1]] * (params.max_nodes - 1), name='tree')
+    self.tree_thresholds = tf.Variable(
+        [-1.0] * (params.max_nodes), name='tree_thresholds')
+    self.tree_depths = tf.Variable(
+        [1] * (params.max_nodes), name='tree_depths')
+    self.end_of_tree = tf.Variable([1], name='end_of_tree')
+
+    self.non_fertile_leaves = tf.Variable([0], name='non_fertile_leaves')
+    self.non_fertile_leaf_scores = tf.Variable(
+        [1.0], name='non_fertile_leaf_scores')
+
+    self.node_to_accumulator_map = tf.Variable(
+        [-1] * params.max_nodes, name='node_to_accumulator_map')
+
+    self.candidate_split_features = tf.Variable(
+        [[-1] * params.num_splits_to_consider] * params.max_fertile_nodes,
+        name='candidate_split_features')
+    self.candidate_split_thresholds = tf.Variable(
+        [[0.0] * params.num_splits_to_consider] * params.max_fertile_nodes,
+        name='candidate_split_thresholds')
+
+    self.node_per_class_weights = tf.Variable(
+        [[0.0] * params.num_classes] * params.max_nodes,
+        name='node_per_class_weights')
+    self.candidate_split_per_class_weights = tf.Variable(
+        [[[0.0] * params.num_classes] * params.num_splits_to_consider] *
+        params.max_fertile_nodes,
+        name='candidate_split_per_class_weights')
+    self.total_split_per_class_weights = tf.Variable(
+        [[-1.0] * params.num_classes] * params.max_fertile_nodes,
+        name='total_split_per_class_weights')
+
+
+class ForestStats(object):
+
+  def __init__(self, tree_stats, params):
+    """A simple container for stats about a forest."""
+    self.tree_stats = tree_stats
+    self.params = params
+
+  def get_average(self, thing):
+    val = 0.0
+    for i in range(self.params.num_trees):
+      val += getattr(self.tree_stats[i], thing)
+
+    return val / self.params.num_trees
+
+
+class TreeStats(object):
+
+  def __init__(self, num_nodes, num_leaves):
+    self.num_nodes = num_nodes
+    self.num_leaves = num_leaves
+
+
+def get_tree_stats(variables, unused_params, session):
+  num_nodes = variables.end_of_tree.eval(session=session) - 1
+  num_leaves = tf.where(
+      tf.equal(tf.squeeze(tf.slice(variables.tree, [0, 0], [-1, 1])),
+               LEAF_NODE)).eval(session=session).shape[0]
+  return TreeStats(num_nodes, num_leaves)
+
+
+def get_forest_stats(variables, params, session):
+
+  tree_stats = []
+  for i in range(params.num_trees):
+    tree_stats.append(get_tree_stats(variables[i], params, session))
+
+  return ForestStats(tree_stats, params)
+
+
+class ForestTrainingVariables(object):
+  """A container for a forests training data, consisting of multiple trees.
+
+  Instantiates a TreeTrainingVariables object for each tree. We override the
+  __getitem__ and __setitem__ function so that usage looks like this:
+
+    forest_variables = ForestTrainingVariables(params)
+
+    ... forest_variables.tree ...
+  """
+
+  def __init__(self, params):
+    self.variables = [TreeTrainingVariables(params)
+                      for _ in range(params.num_trees)]
+
+  def __setitem__(self, t, val):
+    self.variables[t] = val
+
+  def __getitem__(self, t):
+    return self.variables[t]
+
+
+class RandomForestGraphs(object):
+  """Builds TF graphs for random forest training and inference."""
+
+  def __init__(self, params):
+    self.params = params
+    self.variables = ForestTrainingVariables(self.params)
+    self.trees = [RandomTreeGraphs(self.variables[i], self.params,
+                                   training_ops.Load(), inference_ops.Load())
+                  for i in range(self.params.num_trees)]
+
+  def training_graph(self, input_data, input_labels):
+    """Constructs a TF graph for training a random forest.
+
+    Args:
+      input_data: A tensor or placeholder for input data.
+      input_labels: A tensor or placeholder for labels associated with
+        input_data.
+
+    Returns:
+      The last op in the random forest training graph.
+    """
+    tree_graphs = []
+    for i in range(self.params.num_trees):
+      tf.logging.info('Constructing tree %d', i)
+      seed = self.params.base_random_seed
+      if seed != 0:
+        seed += i
+      tree_graphs.append(self.trees[i].training_graph(
+          input_data, input_labels, seed))
+    return tf.group(*tree_graphs)
+
+  def inference_graph(self, input_data):
+    """Constructs a TF graph for evaluating a random forest.
+
+    Args:
+      input_data: A tensor or placeholder for input data.
+
+    Returns:
+      The last op in the random forest inference graph.
+    """
+    probabilities = []
+    for i in range(self.params.num_trees):
+      probabilities.append(self.trees[i].inference_graph(input_data))
+    all_predict = tf.pack(probabilities)
+    return tf.reduce_sum(all_predict, 0) / self.params.num_trees
+
+  def average_impurity(self):
+    """Constructs a TF graph for evaluating the leaf impurity of a forest.
+
+    Returns:
+      The last op in the graph.
+    """
+    impurities = []
+    for i in range(self.params.num_trees):
+      impurities.append(self.trees[i].average_impurity(self.variables[i]))
+    return tf.reduce_mean(tf.pack(impurities))
+
+
+class RandomTreeGraphs(object):
+  """Builds TF graphs for random tree training and inference."""
+
+  def __init__(self, variables, params, t_ops, i_ops):
+    self.training_ops = t_ops
+    self.inference_ops = i_ops
+    self.variables = variables
+    self.params = params
+
+  def _gini(self, class_counts):
+    """Calculate the Gini impurity.
+
+    If c(i) denotes the i-th class count and c = sum_i c(i) then
+      score = 1 - sum_i ( c(i) / c )^2
+
+    Args:
+      class_counts: A 2-D tensor of per-class counts, from either
+        candidate_split_per_class_weights or total_split_per_class_weights.
+
+    Returns:
+      A 1-D tensor of the Gini impurities for each row in the input.
+    """
+    smoothed = 1.0 + class_counts
+    sums = tf.reduce_sum(smoothed, 1)
+    sum_squares = tf.reduce_sum(tf.square(smoothed), 1)
+
+    return 1.0 - sum_squares / (sums * sums)
+
+  def _weighted_gini(self, class_counts):
+    """Our split score is the Gini impurity times the number of examples.
+
+    If c(i) denotes the i-th class count and c = sum_i c(i) then
+      score = c * (1 - sum_i ( c(i) / c )^2 )
+            = c - sum_i c(i)^2 / c
+    Args:
+      class_counts: A 2-D tensor of per-class counts, from either
+        candidate_split_per_class_weights or total_split_per_class_weights.
+
+    Returns:
+      A 1-D tensor of the Gini impurities for each row in the input.
+    """
+    smoothed = 1.0 + class_counts
+    sums = tf.reduce_sum(smoothed, 1)
+    sum_squares = tf.reduce_sum(tf.square(smoothed), 1)
+
+    return sums - sum_squares / sums
+
+  def training_graph(self, input_data, input_labels, random_seed):
+    """Constructs a TF graph for training a random tree.
+
+    Args:
+      input_data: A tensor or placeholder for input data.
+      input_labels: A tensor or placeholder for labels associated with
+        input_data.
+      random_seed: The random number generator seed to use for this tree.  0
+        means use the current time as the seed.
+
+    Returns:
+      The last op in the random tree training graph.
+    """
+    # Count extremely random stats.
+    (pcw_node_delta, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
+     pcw_totals_delta, input_leaves) = (
+         self.training_ops.count_extremely_random_stats(
+             input_data, input_labels, self.variables.tree,
+             self.variables.tree_thresholds,
+             self.variables.node_to_accumulator_map,
+             self.variables.candidate_split_features,
+             self.variables.candidate_split_thresholds,
+             num_classes=self.params.num_classes))
+    node_update_op = tf.assign_add(self.variables.node_per_class_weights,
+                                   pcw_node_delta)
+    candidate_update_op = self.training_ops.scatter_add_ndim(
+        self.variables.candidate_split_per_class_weights,
+        pcw_splits_indices, pcw_splits_delta)
+
+    totals_update_op = self.training_ops.scatter_add_ndim(
+        self.variables.total_split_per_class_weights, pcw_totals_indices,
+        pcw_totals_delta)
+
+    # Sample inputs.
+    update_indices, feature_updates, threshold_updates = (
+        self.training_ops.sample_inputs(
+            input_data, self.variables.node_to_accumulator_map,
+            input_leaves, self.variables.candidate_split_features,
+            self.variables.candidate_split_thresholds,
+            split_initializations_per_input=(
+                self.params.split_initializations_per_input),
+            split_sampling_random_seed=random_seed))
+    update_features_op = tf.scatter_update(
+        self.variables.candidate_split_features, update_indices,
+        feature_updates)
+    update_thresholds_op = tf.scatter_update(
+        self.variables.candidate_split_thresholds, update_indices,
+        threshold_updates)
+
+    # Calculate finished nodes.
+    with tf.control_dependencies([totals_update_op]):
+      children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]),
+                            squeeze_dims=[1])
+      is_leaf = tf.equal(LEAF_NODE, children)
+      leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1]))
+      finished = self.training_ops.finished_nodes(
+          leaves, self.variables.node_to_accumulator_map,
+          self.variables.total_split_per_class_weights,
+          num_split_after_samples=self.params.split_after_samples)
+
+    # Update leaf scores.
+    # TODO(gilberth): Optimize this. It currently calculates counts for
+    # every non-fertile leaf.
+    with tf.control_dependencies([node_update_op]):
+      def f1():
+        return self.variables.non_fertile_leaf_scores
+      def f2():
+        counts = tf.gather(self.variables.node_per_class_weights,
+                           self.variables.non_fertile_leaves)
+        new_scores = self._weighted_gini(counts)
+        return tf.assign(self.variables.non_fertile_leaf_scores, new_scores)
+
+      # Because we can't have tf.self.variables of size 0, we have to put in a
+      # garbage value of -1 in there.  Here we check for that so we don't
+      # try to index into node_per_class_weights in a tf.gather with a negative
+      # number.
+      update_nonfertile_leaves_scores_op = tf.cond(tf.less(
+          self.variables.non_fertile_leaves[0], 0), f1, f2)
+
+    # Calculate best splits.
+    with tf.control_dependencies([candidate_update_op, totals_update_op]):
+      split_indices = self.training_ops.best_splits(
+          finished, self.variables.node_to_accumulator_map,
+          self.variables.candidate_split_per_class_weights,
+          self.variables.total_split_per_class_weights)
+
+    # Grow tree.
+    with tf.control_dependencies([update_features_op, update_thresholds_op]):
+      (tree_update_indices, tree_children_updates,
+       tree_threshold_updates, tree_depth_updates, new_eot) = (
+           self.training_ops.grow_tree(
+               self.variables.end_of_tree, self.variables.tree_depths,
+               self.variables.node_to_accumulator_map, finished, split_indices,
+               self.variables.candidate_split_features,
+               self.variables.candidate_split_thresholds))
+      tree_update_op = tf.scatter_update(
+          self.variables.tree, tree_update_indices, tree_children_updates)
+      threhsolds_update_op = tf.scatter_update(
+          self.variables.tree_thresholds, tree_update_indices,
+          tree_threshold_updates)
+      depth_update_op = tf.scatter_update(
+          self.variables.tree_depths, tree_update_indices, tree_depth_updates)
+
+    # Update fertile slots.
+    with tf.control_dependencies([update_nonfertile_leaves_scores_op,
+                                  depth_update_op]):
+      (node_map_updates, accumulators_cleared, accumulators_allocated,
+       new_nonfertile_leaves, new_nonfertile_leaves_scores) = (
+           self.training_ops.update_fertile_slots(
+               finished, self.variables.non_fertile_leaves,
+               self.variables.non_fertile_leaf_scores,
+               self.variables.end_of_tree, self.variables.tree_depths,
+               self.variables.candidate_split_per_class_weights,
+               self.variables.total_split_per_class_weights,
+               self.variables.node_to_accumulator_map,
+               max_depth=self.params.max_depth))
+
+    # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
+    # used it to calculate new leaves.
+    gated_new_eot, = tf.tuple([new_eot], control_inputs=[new_nonfertile_leaves])
+    eot_update_op = tf.assign(self.variables.end_of_tree, gated_new_eot)
+
+    updates = []
+    updates.append(eot_update_op)
+    updates.append(tree_update_op)
+    updates.append(threhsolds_update_op)
+    updates.append(tf.assign(
+        self.variables.non_fertile_leaves, new_nonfertile_leaves,
+        validate_shape=False))
+    updates.append(tf.assign(
+        self.variables.non_fertile_leaf_scores,
+        new_nonfertile_leaves_scores, validate_shape=False))
+
+    updates.append(tf.scatter_update(
+        self.variables.node_to_accumulator_map,
+        tf.squeeze(tf.slice(node_map_updates, [0, 0], [1, -1]),
+                   squeeze_dims=[0]),
+        tf.squeeze(tf.slice(node_map_updates, [1, 0], [1, -1]),
+                   squeeze_dims=[0])))
+
+    cleared_and_allocated_accumulators = tf.concat(
+        0, [accumulators_cleared, accumulators_allocated])
+    # Calculate values to put into scatter update for candidate counts.
+    # Candidate split counts are always reset back to 0 for both cleared
+    # and allocated accumulators. This means some accumulators might be doubly
+    # reset to 0 if the were released and not allocated, then later allocated.
+    candidate_pcw_values = tf.tile(
+        tf.expand_dims(tf.expand_dims(
+            tf.zeros_like(cleared_and_allocated_accumulators, dtype=tf.float32),
+            1), 2),
+        [1, self.params.num_splits_to_consider, self.params.num_classes])
+    updates.append(tf.scatter_update(
+        self.variables.candidate_split_per_class_weights,
+        cleared_and_allocated_accumulators, candidate_pcw_values))
+
+    # Calculate values to put into scatter update for total counts.
+    total_cleared = tf.tile(
+        tf.expand_dims(
+            tf.neg(tf.ones_like(accumulators_cleared, dtype=tf.float32)), 1),
+        [1, self.params.num_classes])
+    total_reset = tf.tile(
+        tf.expand_dims(
+            tf.zeros_like(accumulators_allocated, dtype=tf.float32), 1),
+        [1, self.params.num_classes])
+    total_pcw_updates = tf.concat(0, [total_cleared, total_reset])
+    updates.append(tf.scatter_update(
+        self.variables.total_split_per_class_weights,
+        cleared_and_allocated_accumulators, total_pcw_updates))
+
+    # Calculate values to put into scatter update for candidate splits.
+    split_features_updates = tf.tile(
+        tf.expand_dims(
+            tf.neg(tf.ones_like(cleared_and_allocated_accumulators)), 1),
+        [1, self.params.num_splits_to_consider])
+    updates.append(tf.scatter_update(
+        self.variables.candidate_split_features,
+        cleared_and_allocated_accumulators, split_features_updates))
+
+    return tf.group(*updates)
+
+  def inference_graph(self, input_data):
+    """Constructs a TF graph for evaluating a random tree.
+
+    Args:
+      input_data: A tensor or placeholder for input data.
+
+    Returns:
+      The last op in the random tree inference graph.
+    """
+    return self.inference_ops.tree_predictions(
+        input_data, self.variables.tree, self.variables.tree_thresholds,
+        self.variables.node_per_class_weights,
+        valid_leaf_threshold=self.params.valid_leaf_threshold)
+
+  def average_impurity(self):
+    """Constructs a TF graph for evaluating the average leaf impurity of a tree.
+
+    Returns:
+      The last op in the graph.
+    """
+    children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]),
+                          squeeze_dims=[1])
+    is_leaf = tf.equal(LEAF_NODE, children)
+    leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1]))
+    counts = tf.gather(self.variables.node_per_class_weights, leaves)
+    impurity = self._weighted_gini(counts)
+    return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0)
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
new file mode 100644
index 0000000..e4846cb
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -0,0 +1,55 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.ops.tensor_forest."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.python import tensor_forest
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class TensorForestTest(test_util.TensorFlowTestCase):
+
+  def testTrainingConstruction(self):
+    input_data = [[-1., 0.], [-1., 2.],  # node 1
+                  [1., 0.], [1., -2.]]  # node 2
+    input_labels = [0, 1, 2, 3]
+
+    params = tensor_forest.ForestHParams(
+        num_classes=4, num_features=2, num_trees=10, max_nodes=1000).fill()
+
+    graph_builder = tensor_forest.RandomForestGraphs(params)
+    graph = graph_builder.training_graph(input_data, input_labels)
+    self.assertTrue(isinstance(graph, tf.Operation))
+
+  def testInferenceConstruction(self):
+    input_data = [[-1., 0.], [-1., 2.],  # node 1
+                  [1., 0.], [1., -2.]]  # node 2
+
+    params = tensor_forest.ForestHParams(
+        num_classes=4, num_features=2, num_trees=10, max_nodes=1000).fill()
+
+    graph_builder = tensor_forest.RandomForestGraphs(params)
+    graph = graph_builder.inference_graph(input_data)
+    self.assertTrue(isinstance(graph, tf.Tensor))
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/tools/ci_build/builds/test_installation.sh b/tensorflow/tools/ci_build/builds/test_installation.sh
index b6385ac..2b87ac3 100755
--- a/tensorflow/tools/ci_build/builds/test_installation.sh
+++ b/tensorflow/tools/ci_build/builds/test_installation.sh
@@ -76,7 +76,8 @@
 
 # Test blacklist: GPU-only
 PY_TEST_GPU_BLACKLIST="${PY_TEST_GPU_BLACKLIST}:"\
-"tensorflow/python/framework/function_test.py"
+"tensorflow/python/framework/function_test.py:"\
+"tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py"
 
 # Tests that should be run in the exclusive mode (i.e., not parallel with
 # other tests)