Fix control dependency problems and add corresponding tests.
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 033d520..a1071d6 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -85,11 +85,12 @@
     copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = [
+        ":test_utils",
         ":trt_allocator",
+        ":trt_conversion",
         ":trt_logging",
         ":trt_plugins",
         ":trt_resources",
-        ":trt_conversion",
         ":utils",
         "//tensorflow/core:gpu_headers_lib",
         "//tensorflow/core:lib_proto_parsing",
@@ -192,6 +193,7 @@
         "//tensorflow/python:platform/base.i",
     ],
     deps = [
+        ":test_utils",
         ":trt_conversion",
         ":trt_engine_op_kernel",
         "//third_party/python_runtime:headers",
@@ -264,6 +266,7 @@
     ],
     deps = [
         ":segment",
+        ":test_utils",
         ":trt_allocator",
         ":trt_plugins",
         ":trt_logging",
@@ -412,3 +415,12 @@
     hdrs = ["convert/utils.h"],
     copts = tf_copts(),
 )
+
+cc_library(
+    name = "test_utils",
+    srcs = ["test/utils.cc"],
+    hdrs = ["test/utils.h"],
+    deps = [
+        "//tensorflow/core:lib",
+    ],
+)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 22909a1..1e63005 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -20,6 +20,7 @@
 #include <map>
 #include <set>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -29,6 +30,7 @@
 #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
 #include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
 #include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
@@ -49,9 +51,9 @@
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/protobuf/config.pb.h"             // NOLINT
+#include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
 #include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
-#include "tensorflow/core/protobuf/rewriter_config.pb.h"    // NOLINT
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"  // NOLINT
 #include "tensorflow/core/util/device_name_utils.h"
 
 #if GOOGLE_CUDA
@@ -260,63 +262,6 @@
   return ConvertAfterShapes(cp);
 }
 
-bool IsUniformTensorValue(const tensorflow::TensorProto& tensor) {
-  using tensorflow::DataType;
-  switch (tensor.dtype()) {
-    case DataType::DT_HALF:  // fall-through
-    case DataType::DT_BFLOAT16:
-      return tensor.half_val_size() == 1;
-    case DataType::DT_FLOAT:
-      return tensor.float_val_size() == 1;
-    case DataType::DT_DOUBLE:
-      return tensor.double_val_size() == 1;
-    case DataType::DT_INT32:  // fall-through
-    case DataType::DT_INT16:  // fall-through
-    case DataType::DT_INT8:   // fall-through
-    case DataType::DT_UINT8:
-      return tensor.int_val_size() == 1;
-    case DataType::DT_STRING:
-      return tensor.string_val_size() == 1;
-    case DataType::DT_COMPLEX64:
-      return tensor.scomplex_val_size() == 1;
-    case DataType::DT_INT64:
-      return tensor.int64_val_size() == 1;
-    case DataType::DT_BOOL:
-      return tensor.bool_val_size() == 1;
-    case DataType::DT_COMPLEX128:
-      return tensor.dcomplex_val_size() == 1;
-    case DataType::DT_RESOURCE:
-      return tensor.resource_handle_val_size() == 1;
-    case DataType::DT_VARIANT:
-      return tensor.variant_val_size() == 1;
-    case DataType::DT_UINT32:
-      return tensor.uint32_val_size() == 1;
-    case DataType::DT_UINT64:
-      return tensor.uint64_val_size() == 1;
-    default:
-      return false;
-  }
-}
-
-std::unordered_set<int> GetAttributeInputs(const tensorflow::Node* node) {
-  typedef std::unordered_map<string, std::unordered_set<int>> InputMap;
-  static const InputMap attribute_inputs = {
-      {"Concat", {0}}, {"ConcatV2", {-1}}, {"Reshape", {1}}};
-  auto iter = attribute_inputs.find(node->type_string());
-  if (iter != attribute_inputs.end()) {
-    // Apply reverse indexing
-    std::unordered_set<int> result;
-    for (int idx : iter->second) {
-      if (idx < 0) {
-        idx += node->num_inputs();
-      }
-      result.insert(idx);
-    }
-    return result;
-  }
-  return {};
-}
-
 // Function to get subsegment information structure.
 tensorflow::Status GetEngineInfo(
     const tensorflow::Graph* g,
@@ -325,13 +270,10 @@
     const std::unordered_map<string, tensorflow::Node*>& node_map,
     const std::vector<tensorflow::Node*>& reverse_topo_order,
     EngineInfo* info) {
-  std::vector<int> subgraph_node_ids;
+  std::vector<int> subgraph_node_ids;  // Topologically sorted node ids.
+  std::set<string> subgraph_node_names = segment_nodes;
   std::set<int> added_const_node_ids;  // Used to prevent double insertion.
   std::set<string> segment_devices;
-  std::unordered_set<string> segment_consts;
-  std::vector<int> const_node_ids;
-  int input_port = 0;
-  int output_port = 0;
 
   // Map from src_node_name+port to the unique port numbers of the TRT op, where
   // the src_node_name is the name of the source node of the input/output
@@ -339,7 +281,7 @@
   // input/output edges must be in different split of the graph.
   // TODO(aaroey): consider using node id and port instead.
   // TODO(aaroey): using topo order instead of reverting reverse topo order.
-  std::unordered_map<string, int> created_edges;
+  std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
   for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
        ++it) {
     const auto& node_name = (*it)->name();
@@ -358,133 +300,114 @@
     }
     const int node_id = node->id();
     subgraph_node_ids.push_back(node_id);
+    // Create input connections.
     for (const auto edge : node->in_edges()) {
       auto input_node = edge->src();
-      if (input_node->IsSource()) continue;
-      if (segment_nodes.count(input_node->name()) == 0) {
-        // Add constant input node into the segment. We don't care if it has
-        // other output edges going into other engines or TF nodes. Since we add
-        // it only to the subsegment node list, not the subsegment itself, it
-        // won't be removed from the graph. If it doesn't have any edges, TF
-        // will prune it out.
-        if (input_node->type_string() == "Const") {
-          bool is_supported = input_node->output_type(0) == DT_FLOAT ||
-                              input_node->output_type(0) == DT_INT32;
-          bool is_attribute_input =
-              GetAttributeInputs(node).count(edge->dst_input()) != 0;
-          const tensorflow::TensorProto& tensor_proto =
-              input_node->def().attr().at("value").tensor();
-          bool is_uniform = IsUniformTensorValue(tensor_proto);
-
-          // Const can be absorbed
-          if (is_supported && is_attribute_input && is_uniform) {
-            if (segment_consts.count(input_node->name()) != 0) {
-              continue;  // skip if already added
-            }
-            VLOG(0) << "Adding const node " << input_node->name();
-            const_node_ids.push_back(input_node->id());
-            segment_consts.insert(input_node->name());
-            int conn_count = 0;
-            for (auto cinp_e :
-                 input_node->in_edges()) {  // must be Control edges
-              if (!cinp_e->IsControlEdge() || cinp_e->src()->IsSource()) {
-                conn_count++;
-                continue;
-              }
-              VLOG(0) << info->engine_name << ": Control edge " << conn_count
-                      << " from node " << input_node->name()
-                      << " edge= " << cinp_e->src()->name();
-              auto cinp = cinp_e->src();
-              EngineConnection ec(cinp->name(), cinp->id(),
-                                  cinp_e->src_output(), input_node->name(),
-                                  input_node->id(), cinp_e->dst_input(), true,
-                                  -1, true);
-              info->connections.emplace_back(std::move(ec));
-            }
-            continue;
-          }
+      if (input_node->IsSource() || segment_nodes.count(input_node->name())) {
+        continue;
+      }
+      if (edge->IsControlEdge()) {
+        // Control input.
+        info->connections.emplace_back(
+            input_node->name(), input_node->id(), node_name, node_id,
+            /*input_edge=*/true);
+      } else if (input_node->type_string() == "Const") {
+        // Add constant data input nodes into the segment graphdef (thus also in
+        // the engine). We don't care if it has other output edges going into
+        // other engines or TF nodes. Since we add it only to the segment
+        // graphdef, not the segment itself, it won't be removed from the graph.
+        // If it doesn't have any edges, TF will prune it out.
+        // Note that the constant data input must be supported by the engine
+        // regardless of the datatype, since the segmenter already removed
+        // unsupported data input nodes.
+        if (!added_const_node_ids.insert(input_node->id()).second) {
+          // Already added before.
+          continue;
         }
-
-        // Non-const data/control edge
-        if (!edge->IsControlEdge()) {
-          string s(input_node->name());
-          StrAppend(&s, ":", edge->src_output());
-          VLOG(1) << "Input edge = " << s;
-          int port = input_port;
-          if (created_edges.count(s)) {
-            port = created_edges.at(s);
-          } else {
-            created_edges.insert({s, port});
-            input_port++;
-          }
-          EngineConnection ec(input_node->name(), input_node->id(),
-                              edge->src_output(), node_name, node_id,
-                              edge->dst_input(), true, port);
-          ec.connection_type = input_node->output_type(edge->src_output());
-          info->connections.emplace_back(std::move(ec));
+        VLOG(1) << "Adding const node " << input_node->name();
+        QCHECK(subgraph_node_names.insert(input_node->name()).second);
+#if 1
+        // Since we duplicate the const input node in both the segment graphdef
+        // and the engine, the segment node doesn't depend on it anymore, so we
+        // add a control dependency instead.
+        info->connections.emplace_back(
+            input_node->name(), input_node->id(), node_name, node_id,
+            /*input_edge=*/true);
+#else
+        // Add control inputs to the const node as control input connections to
+        // the engine.
+        for (const auto const_in_edge : input_node->in_edges()) {
+          QCHECK(const_in_edge->IsControlEdge());  // Must be control edge.
+          auto const_in_node = const_in_edge->src();
+          QCHECK(!segment_nodes.count(const_in_node->name()))
+              << "Loop found between segment and non-segment nodes, from "
+                 "segment node "
+              << const_in_node->name() << " to non-segment node "
+              << input_node->name() << " to segment node " << node->name();
+          if (const_in_node->IsSource()) continue;
+          VLOG(1) << "Control edge from node " << const_in_node->name()
+                  << " to " << input_node->name();
+          info->connections.emplace_back(
+              const_in_node->name(), const_in_node->id(), input_node->name(),
+              input_node->id(), /*input_edge=*/true);
+        }
+#endif
+      } else {
+        // Non-const data input.
+        int port = Graph::kControlSlot - 1;
+        // Use the source non-segment node name/port as key.
+        const string s = StrCat(input_node->name(), ":", edge->src_output());
+        VLOG(1) << "Input edge = " << s;
+        if (input_to_engine_port.count(s)) {
+          port = input_to_engine_port.at(s);
         } else {
-          EngineConnection ec(input_node->name(), input_node->id(),
-                              edge->src_output(), node_name, node_id,
-                              edge->dst_input(), true, -1, true);
-          ec.connection_type = input_node->output_type(edge->src_output());
-          info->connections.emplace_back(std::move(ec));
+          port = input_to_engine_port.size();
+          input_to_engine_port.insert({s, port});
         }
+        info->connections.emplace_back(input_node->name(), input_node->id(),
+                                       edge->src_output(), node_name, node_id,
+                                       edge->dst_input(), /*input_edge=*/true,
+                                       port);
       }
     }
-
+    // Create output connections.
     for (const auto edge : node->out_edges()) {
       auto output_node = edge->dst();
-      if (output_node->IsSink()) continue;
-      if (segment_nodes.count(output_node->name()) == 0) {
-        if (!edge->IsControlEdge()) {
-          string s(node_name);
-          StrAppend(&s, ":", edge->src_output());
-          VLOG(1) << "Output edge = " << s;
-          int port = output_port;
-          if (created_edges.count(s)) {
-            port = created_edges.at(s);
-          } else {
-            created_edges.insert({s, port});
-            output_port++;
-          }
-          info->connections.emplace_back(output_node->name(), output_node->id(),
-                                         edge->dst_input(), node_name, node_id,
-                                         edge->src_output(), false, port);
+      if (output_node->IsSink() || segment_nodes.count(output_node->name())) {
+        continue;
+      }
+      if (edge->IsControlEdge()) {
+        // Control output.
+        info->connections.emplace_back(
+            output_node->name(), output_node->id(), node_name, node_id,
+            /*input_edge=*/false);
+      } else {
+        // Data output.
+        int port = Graph::kControlSlot - 1;
+        // Use the source segment node name/port as key.
+        const string s = StrCat(node_name, ":", edge->src_output());
+        VLOG(1) << "Output edge = " << s;
+        if (output_to_engine_port.count(s)) {
+          port = output_to_engine_port.at(s);
         } else {
-          info->connections.emplace_back(output_node->name(), output_node->id(),
-                                         edge->dst_input(), node_name, node_id,
-                                         edge->src_output(), false, -1, true);
+          port = output_to_engine_port.size();
+          output_to_engine_port.insert({s, port});
         }
+        info->connections.emplace_back(output_node->name(), output_node->id(),
+                                       edge->dst_input(), node_name, node_id,
+                                       edge->src_output(), /*input_edge=*/false,
+                                       port);
       }
     }
-  }
+  }  // For each segment node in topological order.
 
-  // Fix control edges
-  for (size_t t = 0; t < info->connections.size(); t++) {
-    auto& conn = info->connections.at(t);
-    if (conn.is_control_edge) {
-      for (size_t k = 0; k < info->connections.size(); k++) {
-        if (k == t) continue;
-        const auto& other = info->connections.at(k);
-        if (conn.outside_id == other.outside_id && other.port_number != -1) {
-          VLOG(0) << "Updating control edge " << conn.outside_node_name
-                  << " -> " << conn.inside_node_name << " to input port "
-                  << other.port_number;
-          conn.port_number = other.port_number;
-          break;
-        }
-      }
-    }
-  }
-
-  // Construct the const nodes first
-  subgraph_node_ids.insert(subgraph_node_ids.begin(), const_node_ids.begin(),
-                           const_node_ids.end());
+  // Construct the const nodes first.
+  subgraph_node_ids.insert(subgraph_node_ids.begin(),
+                           added_const_node_ids.begin(),
+                           added_const_node_ids.end());
   TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
-      g, graph_properties, subgraph_node_ids, &info->connections,
-      &info->segment_graph_def, &info->engine_name));
-  info->engine_type = EngineInfo::EngineType::TRTStatic;
-
+      g, graph_properties, subgraph_node_names, subgraph_node_ids,
+      &info->connections, &info->segment_graph_def, &info->engine_name));
   // TODO(sami): This should not happen once segmenter is updated.
   if (segment_devices.size() == 1) {
     info->device = *segment_devices.begin();
@@ -502,36 +425,34 @@
 // Helper function to update edge connection from the removed node to the
 // engine node. If an outside node is gone, it must have been absorbed into
 // an engine node. Find the engine node.
-void UpdateToEngineNode(tensorflow::Node*& node, string& node_name, int& port,
-                        const std::vector<EngineInfo>& infos,
-                        size_t my_engine_id,
+void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
+                        const size_t my_engine_id,
                         const std::vector<Node*>& engine_nodes,
-                        bool update_input_edge) {
-  bool found_engine = false;
+                        const bool is_input_edge,
+                        const string& node_name,
+                        tensorflow::Node** node, int* port) {
   for (size_t t = 0; t < infos.size(); ++t) {
     if (t == my_engine_id) {
       continue;
     }
-    auto& connected_eng_info = infos.at(t);
-    for (const auto& eng_conn : connected_eng_info.connections) {
-      if (update_input_edge && eng_conn.is_input_edge) {
-        continue;
-      } else if (!update_input_edge && !eng_conn.is_input_edge) {
-        continue;
-      }
+    const auto& info = infos.at(t);
+    for (const auto& eng_conn : info.connections) {
+      // If the connection being updated is an input connection, the source of
+      // the connection must be an output connection of another engine. And vice
+      // versa.
+      if (is_input_edge == eng_conn.is_input_edge) continue;
       if (eng_conn.inside_node_name == node_name &&
-          eng_conn.inside_port == port) {
-        node = engine_nodes[t];
-        node_name = connected_eng_info.engine_name;
-        port = eng_conn.port_number;
-        found_engine = true;
-        break;
+          eng_conn.inside_port == *port) {
+        *node = CHECK_NOTNULL(engine_nodes[t]);
+        QCHECK_EQ(info.engine_name, (**node).name())
+            << "Engine name mismatch: " << info.engine_name << " vs "
+            << (**node).name();
+        *port = eng_conn.port_number;
+        return;
       }
     }
-    if (found_engine) break;
   }
-  CHECK(found_engine);
-  CHECK(node != nullptr);
+  LOG(FATAL) << "Node " << (**node).name() << " not found in any engine.";
 }
 
 // Function to insert a TRT engine node into the graph.
@@ -539,114 +460,91 @@
 // 1. Each invocation of CreateTRTNode creates an engine node for infos[pos]
 // 2. When an engine node is created, add it into the graph with necessary
 //    re-wiring.
-//   2.1. If the outside connected node is existing, connect the engine
-//        node to it.
-//   2.2. If the outside connected node is gone, it must have been absorted
-//        into another engine node (which was processed before the processing
-//        one). Connect to the pre-existing engine node instead.
+//    2.1. If the outside connected node is existing, connect the engine
+//         node to it.
+//    2.2. If the outside connected node is gone, it must have been absorbed
+//         into another engine node (which was processed before the processing
+//         one). Connect to the pre-existing engine node instead.
 // 3. In this way, we ensure the graph is topologically sort-able after each
 //    invocation of CreateTRTNode().
-
-tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
-                                 const std::vector<EngineInfo>& infos, int pos,
-                                 tensorflow::Allocator* alloc,
-                                 int max_batch_size,
-                                 std::vector<Node*>& engine_nodes) {
-  auto& info = infos.at(pos);
+tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
+                                 int max_batch_size, tensorflow::Graph* graph,
+                                 nvinfer1::IGpuAllocator* alloc,
+                                 std::vector<Node*>* engine_nodes) {
+  const auto& info = infos.at(pos);
+  TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail");
   std::vector<tensorflow::TensorShapeProto> output_shape_protos;
   std::vector<tensorflow::TensorShapeProto> input_shape_protos;
-  std::vector<tensorflow::PartialTensorShape> shapes;
+  std::vector<tensorflow::PartialTensorShape> input_shapes;
   std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs;
   std::vector<tensorflow::Node*> input_nodes;
   std::vector<tensorflow::Node*> control_input_nodes;
-  std::vector<string> control_input_names;
+  std::unordered_set<string> control_input_names;
   std::vector<tensorflow::DataType> out_types;
 
   VLOG(1) << "Processing " << info.engine_name;
-
-  // -- Preprocessing -- //
-  // collect needed info for creating the engine node in the graph
-  for (const auto conn : info.connections) {
-    // control edges
-    if (conn.is_control_edge) {
-      // skip control outputs for now. control output info are not needed for
+  // Collect needed info for creating the engine node in the graph
+  for (const auto& conn : info.connections) {
+    // Control edges
+    if (conn.is_control_edge()) {
+      // Skip control outputs for now. Control output info is not needed for
       // node creation and will be processed later.
-      if (!conn.is_input_edge) {
-        continue;
-      }
+      if (!conn.is_input_edge) continue;
 
-      // control inputs
+      // Rewire the control input if it's not found in the original graph.
       tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);
-      string input_node_name = conn.outside_node_name;
       int port = tensorflow::Graph::kControlSlot;
       if (!input_node) {
-        UpdateToEngineNode(input_node, input_node_name, port, infos, pos,
-                           engine_nodes, true);
+        UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
+                           conn.outside_node_name, &input_node, &port);
+        QCHECK_EQ(Graph::kControlSlot, port);
       }
-      bool new_input = true;
-      for (const auto& name : control_input_names) {
-        if (name == input_node_name) {
-          new_input = false;
-          break;
-        }
+      if (!control_input_names.insert(input_node->name()).second) {
+        continue;
       }
-      if (new_input) {
-        control_input_nodes.push_back(input_node);
-        control_input_names.push_back(input_node_name);
-
-        VLOG(1) << "Engine Control Input " << input_node_name << ":" << port
-                << " -> " << info.engine_name << ":"
-                << tensorflow::Graph::kControlSlot;
-      }
-
-      // data edges
+      control_input_nodes.push_back(input_node);
+      VLOG(1) << "Engine Control Input " << input_node->name()
+              << " -> " << info.engine_name;
     } else {
-      // data outputs
+      // Data edges
       if (!conn.is_input_edge) {
+        // Set the shapes and data types of output edge.
         tensorflow::TensorShapeProto out_shape;
-        conn.inside_shape.AsProto(
-            &out_shape);  // shape of the output node inside segment
+        // shape of the output node inside segment
+        conn.inside_shape.AsProto(&out_shape);
         if (output_shape_protos.size() <= conn.port_number) {
           output_shape_protos.resize(conn.port_number + 1);
           out_types.resize(conn.port_number + 1);
         }
         output_shape_protos.at(conn.port_number) = out_shape;
         out_types.at(conn.port_number) = conn.connection_type;
-
-        // data input
       } else {
+        // Set the shapes and data types of input edge.
         tensorflow::TensorShapeProto in_shape;
         conn.outside_shape.AsProto(&in_shape);
-
         if (input_shape_protos.size() <= conn.port_number) {
           input_shape_protos.resize(conn.port_number + 1);
-          shapes.resize(conn.port_number + 1);
+          input_shapes.resize(conn.port_number + 1);
         }
         input_shape_protos.at(conn.port_number) = in_shape;
-        shapes.at(conn.port_number) = conn.outside_shape;
+        input_shapes.at(conn.port_number) = conn.outside_shape;
 
+        // Rewire the data input if it's not found in the original graph.
         tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);
-        string input_node_name = conn.outside_node_name;
-        int input_port = conn.outside_port;
-        auto dtype = conn.connection_type;
-
+        int port = conn.outside_port;
         if (!input_node) {
-          UpdateToEngineNode(input_node, input_node_name, input_port, infos,
-                             pos, engine_nodes, true);
+          UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
+                             conn.outside_node_name, &input_node, &port);
         }
-        bool new_input = true;
-        for (const auto& inp : inputs) {
-          if (inp.node == input_node_name && inp.index == input_port) {
-            new_input = false;
-            break;
-          }
-        }
-        if (new_input) {
-          inputs.emplace_back(input_node_name, input_port, dtype);
-          CHECK(input_node != nullptr);
-          input_nodes.push_back(input_node);
-
-          VLOG(1) << "Engine Input " << input_node_name << ":" << input_port
+        if (std::find_if(std::begin(inputs), std::end(inputs),
+                         [input_node, &port](
+                             const NodeDefBuilder::NodeOut& inp) {
+                           return inp.node == input_node->name() &&
+                               inp.index == port;
+                         }) == std::end(inputs)) {
+          inputs.emplace_back(input_node->name(), port, conn.connection_type);
+          input_nodes.push_back(CHECK_NOTNULL(input_node));
+          VLOG(1) << "Engine Input " << input_node->name() << ":" << port
                   << " -> " << info.engine_name << ":" << inputs.size() - 1;
         }
       }
@@ -662,14 +560,12 @@
     // Otherwise we skip node creation for this engine.
     Logger trt_logger;
     TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
-    std::unique_ptr<TRTDeviceAllocator> allocator(
-        new TRTDeviceAllocator(alloc));
     // TODO(sami): What happens if 1st dim is not batch?
     TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
         info.segment_graph_def,
         info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode,
-        max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger,
-        allocator.get(), /*calibrator=*/nullptr, &engine,
+        max_batch_size, info.max_workspace_size_bytes, input_shapes,
+        &trt_logger, alloc, /*calibrator=*/nullptr, &engine,
         /*convert_successfully=*/nullptr));
     TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
     segment_string =
@@ -711,7 +607,7 @@
     VLOG(1) << ins;
   }
   node_builder.Input(inputs);
-  for (auto& c : control_input_names) {
+  for (const string& c : control_input_names) {
     node_builder.ControlInput(c);
   }
 
@@ -744,54 +640,50 @@
   // Up until this point, graph is not modified. If we return !status.ok() from
   // here, this segment will be skipped
   tensorflow::Node* engine_node = graph->AddNode(trt_node, &status);
-  engine_nodes[pos] = engine_node;
+  (*engine_nodes)[pos] = engine_node;
   if (!status.ok()) {
     LOG(ERROR) << "Adding node failed " << status;
     return status;
   }
-  // input edges of the engine node
-  for (auto in : control_input_nodes) {
+  // Add control input and input edges to the engine node.
+  for (const auto in : control_input_nodes) {
     VLOG(1) << "Connecting control edge from " << in->name() << " to "
             << engine_node->name();
     graph->AddControlEdge(in, engine_node);
   }
-  int idx = 0;
   VLOG(1) << "input_nodes size = " << input_nodes.size();
-  for (auto in : inputs) {
-    Node* n = input_nodes[idx];
-    CHECK(n != nullptr);
+  for (int i = 0; i < input_nodes.size(); ++i) {
+    Node* n = input_nodes[i];
+    const auto& in = inputs[i];
+    CHECK_NOTNULL(n);
     VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
-            << " to " << engine_node->name() << ":" << idx;
-    graph->AddEdge(n, in.index, engine_node, idx++);
+            << " to " << engine_node->name() << ":" << i;
+    graph->AddEdge(n, in.index, engine_node, i);
   }
+
   // Updates the inputs of output edges destination nodes, and point them to the
   // engine node.
-
   for (auto& conn : info.connections) {
     if (conn.is_input_edge) {
       continue;
     }
-
-    string out_name = conn.outside_node_name;
-    auto out_node = graph->FindNodeId(conn.outside_id);
-    int out_port = conn.outside_port;
-
-    if (!out_node) {
-      UpdateToEngineNode(out_node, out_name, out_port, infos, pos, engine_nodes,
-                         false);
+    tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id);
+    int port = conn.outside_port;
+    if (!output_node) {
+      UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
+                         conn.outside_node_name, &output_node, &port);
     }
-
     VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number
-            << " to " << out_node->name() << ":" << out_port;
-
-    if (conn.is_control_edge) {
-      graph->AddControlEdge(engine_node, out_node);
+            << " to " << output_node->name() << ":" << port;
+    if (conn.is_control_edge()) {
+      QCHECK_EQ(Graph::kControlSlot, port);
+      graph->AddControlEdge(engine_node, output_node);
     } else {
       auto new_edge =
-          graph->AddEdge(engine_node, conn.port_number, out_node, out_port);
-      CHECK(new_edge) << "Adding a new edge failed " << engine_node->name()
-                      << ":" << conn.port_number << " -> " << out_node->name()
-                      << ":" << conn.outside_port;
+          graph->AddEdge(engine_node, conn.port_number, output_node, port);
+      QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name()
+                       << ":" << conn.port_number << " -> "
+                       << output_node->name() << ":" << conn.outside_port;
     }
   }
   return status;
@@ -1077,19 +969,21 @@
       LOG(WARNING) << "Can't identify the cuda device. Running on device 0 ";
     }
     cudaSetDevice(cuda_device_id);
-    auto status = CreateTRTNode(&graph, engine_segments, i, device_alloc.second,
-                                params.max_batch_size, engine_nodes);
+    auto status = CreateTRTNode(engine_segments, i, params.max_batch_size,
+                                &graph, alloc.get(), &engine_nodes);
     // If status is ok, we successfully added the node to the graph and can
     // remove segment ops. Otherwise graph is not modified.
+    const string msg = StrCat(
+        "Engine ", engine.engine_name, " creation for segment ", i,
+        ", composed of ", converted_segments.at(i).first.size(), " nodes");
     if (status.ok()) {
+      LOG(INFO) << msg << " succeeded.";
       for (auto node_name : converted_segments.at(i).first) {
         graph.RemoveNode(node_map.at(node_name));
       }
     } else {
       // Graph is not modified.
-      LOG(WARNING) << "Engine creation for segment " << i << ", composed of "
-                   << converted_segments.at(i).first.size()
-                   << " nodes failed: " << status << ". Skipping...";
+      LOG(WARNING) << msg << " failed: " << status << ". Skipping...";
     }
   }
   cudaSetDevice(old_cuda_device);
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 451d6fe..3b0ac43 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -22,6 +22,7 @@
 #include <memory>
 #include <set>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -2788,6 +2789,7 @@
 tensorflow::Status ConvertSegmentToGraphDef(
     const tensorflow::Graph* graph,
     const tensorflow::grappler::GraphProperties& graph_properties,
+    const std::set<string>& subgraph_node_names,
     const std::vector<int>& subgraph_node_ids,  // In topological order
     std::vector<EngineConnection>* connections,
     tensorflow::GraphDef* segment_def, string* common_scope) {
@@ -2796,6 +2798,7 @@
   // nodes in the segment graphdef.
   for (size_t i = 0; i < connections->size(); ++i) {
     auto& connection = connections->at(i);
+    if (connection.is_control_edge()) continue;
     auto outside_node = graph->FindNodeId(connection.outside_id);
     if (!outside_node) {
       // This should never happen, unless the original graph is problematic.
@@ -2809,13 +2812,13 @@
       GetInputProperties(graph_properties,
                          graph->FindNodeId(connection.outside_id),
                          connection.outside_port, &partial_shape, &dtype);
-
+      connection.outside_shape = partial_shape;
     } else {
       GetOutputProperties(graph_properties,
                           graph->FindNodeId(connection.outside_id),
                           connection.outside_port, &partial_shape, &dtype);
+      connection.inside_shape = partial_shape;
     }
-    connection.outside_shape = partial_shape;
     connection.connection_type = dtype;
 
     // Add dummy input/output nodes to the segment graphdef.
@@ -2873,7 +2876,7 @@
   // Update the inputs of the new input nodes to point to placeholder nodes.
   for (int i = 0; i < connections->size(); ++i) {
     auto& connection = connections->at(i);
-    if (!connection.is_input_edge) continue;
+    if (connection.is_control_edge() || !connection.is_input_edge) continue;
     auto snode =
         segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
     const string placeholder_name =
@@ -2883,6 +2886,38 @@
             << placeholder_name;
     snode->set_input(connection.inside_port, placeholder_name);
   }
+  // Remove control inputs that are not inside the segment.
+  for (int i = 0; i < segment_def->node_size(); ++i) {
+    auto snode = segment_def->mutable_node(i);
+    const int input_size = snode->input_size();
+    int input_idx = 0;
+    int actual_input_idx = 0;
+    while (input_idx < input_size) {
+      TensorId input = ParseTensorName(snode->input(input_idx));
+      if (!subgraph_node_names.count(
+              string(input.first.data(), input.first.size())) &&
+          !str_util::StartsWith(input.first, kInputPHName)) {
+        if (input.second == Graph::kControlSlot) {
+          VLOG(2) << "... removing control inputs " << input.first
+                  << " from subgraph.";
+          ++input_idx;
+          continue;
+        } else {
+          return tensorflow::errors::InvalidArgument(
+              "Found non control input outside the segment that is not an "
+              "engine connection to ", snode->name(), ": ", input.first);
+        }
+      }
+      if (actual_input_idx != input_idx) {
+        snode->set_input(actual_input_idx, snode->input(input_idx));
+      }
+      ++input_idx;
+      ++actual_input_idx;
+    }
+    for (int remove = input_size - actual_input_idx; remove > 0; --remove) {
+      snode->mutable_input()->RemoveLast();
+    }
+  }
   *common_scope = local_scope;
   VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph";
   return tensorflow::Status::OK();
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index d41a886..328efbf 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -36,8 +36,8 @@
 
 namespace tensorflow {
 namespace tensorrt {
-static const char* kInputPHName = "InputPH_";
-static const char* kOutputPHName = "OutputPH_";
+static const char* kInputPHName = "TensorRTInputPH_";
+static const char* kOutputPHName = "TensorRTOutputPH_";
 namespace convert {
 
 // TODO(aaroey): use an enum instead.
@@ -46,9 +46,10 @@
 const int INT8MODE = 2;
 
 struct EngineConnection {
+  // Constructs a non-control edge.
   EngineConnection(const string& outside, int out_id, int out_port,
                    const string& inside, int in_id, int in_port,
-                   bool input_edge, int port, bool control_edge = false)
+                   bool input_edge, int port)
       : outside_node_name(outside),
         outside_id(out_id),
         outside_port(out_port),
@@ -56,24 +57,39 @@
         inside_id(in_id),
         inside_port(in_port),
         is_input_edge(input_edge),
-        is_control_edge(control_edge),
         port_number(port) {}
 
+  // Constructs a control edge.
+  EngineConnection(const string& outside, int out_id, const string& inside,
+                   int in_id, bool input_edge)
+      : outside_node_name(outside),
+        outside_id(out_id),
+        outside_port(Graph::kControlSlot),
+        inside_node_name(inside),
+        inside_id(in_id),
+        inside_port(Graph::kControlSlot),
+        is_input_edge(input_edge),
+        port_number(Graph::kControlSlot) {}
+
+  bool is_control_edge() const {
+    return port_number == Graph::kControlSlot;
+  }
+
   const string outside_node_name;
   const int outside_id;
   const int outside_port;
-  tensorflow::PartialTensorShape outside_shape;
+  tensorflow::PartialTensorShape outside_shape;  // Only set for input edge.
 
   const string inside_node_name;
   const int inside_id;
   const int inside_port;
-  tensorflow::PartialTensorShape inside_shape;
+  tensorflow::PartialTensorShape inside_shape;  // Only set for output edge.
 
   tensorflow::DataType connection_type;
-  bool is_input_edge;
-  bool is_control_edge;
-  // The port number of the TRT node connecting to this edge.
-  int port_number;
+  const bool is_input_edge;
+
+  // The port number of the TRT node connected with this edge.
+  const int port_number;
 };
 
 struct EngineInfo {
@@ -86,7 +102,9 @@
   string device;
   tensorflow::GraphDef segment_graph_def;
 
-  // The segment nodes that are on one side of the edges are topological sorted.
+  // Non-control input connections inside this vector are sorted in a way such
+  // that the segment nodes connecting to them are topologically sorted.
+  // In addition, for non-control connections, there must be no duplicates.
   std::vector<EngineConnection> connections;
 
   enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
@@ -102,6 +120,7 @@
 // (OutputPH_*). This function needs to be called before TensorRT nodes
 // inserted in order to correctly get sizes from the original graph.
 //
+// - subgraph_node_names: the node names of the subgraph.
 // - subgraph_node_ids: the node ids of the subgraph, must be sorted in
 //   topological order.
 // - segment_def: the output GraphDef, whose non-input/output nodedefs will be
@@ -111,6 +130,7 @@
 tensorflow::Status ConvertSegmentToGraphDef(
     const tensorflow::Graph* graph,
     const tensorflow::grappler::GraphProperties& graph_properties,
+    const std::set<string>& subgraph_node_names,
     const std::vector<int>& subgraph_node_ids,
     std::vector<EngineConnection>* connections,
     tensorflow::GraphDef* segment_def, string* common_scope);
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 6699b71..a19cd24 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -22,6 +22,7 @@
 #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
 #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
 #include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -179,7 +180,7 @@
   helper->Ref();  // Increment count for calculating native graph
   VLOG(1) << "Executing native segment " << name();
   lib->Run(opts, native_func_, inputs, outputs,
-           [ctx, outputs, helper](const tensorflow::Status& s) {
+           [this, ctx, outputs, helper](const tensorflow::Status& s) {
              tensorflow::core::ScopedUnref sc(helper);
              VLOG(1) << "Native Segment completed";
              if (!s.ok()) {
@@ -189,6 +190,8 @@
              for (size_t t = 0; t < outputs->size(); ++t) {
                ctx->set_output(t, outputs->at(t));
              }
+             test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"),
+                                "done");
              delete outputs;
            });
 }
@@ -234,6 +237,7 @@
                                                 ->implementation()
                                                 ->GpuStreamMemberHack()));
   calib_res->calibrator_->setBatch(input_data, *stream);
+  test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done");
   VLOG(2) << "Passed calibration data";
   ExecuteNativeSegment(ctx, helper);
 }
@@ -258,7 +262,7 @@
           StrCat("Engine buffer is full. buffer limit=", max_cached_engines_,
                  ", current entries=");
       for (auto i : cached_engine_batches_) StrAppend(&msg, i, ",");
-      StrAppend(&msg, "Requested batch=", num_batch);
+      StrAppend(&msg, " requested batch=", num_batch);
       LOG(WARNING) << msg;
       return -1;
     }
@@ -276,7 +280,8 @@
   }
   const int smallest_engine = GetEngineBatch(ctx);
   if (smallest_engine < 0) {
-    LOG(WARNING) << "Failed to get engine batch, running native segment";
+    LOG(WARNING) << "Failed to get engine batch, running native segment for "
+                 << name();
     ExecuteNativeSegment(ctx, helper);
     return;
   }
@@ -286,14 +291,15 @@
   auto& trt_engine_ptr = engine_ctx_pair.first;
   if (!trt_engine_ptr) {
     LOG(WARNING) << "Engine retrieval for batch size " << num_batch
-                 << " failed. Running native segment";
+                 << " failed. Running native segment for " << name();
     ExecuteNativeSegment(ctx, helper);
     return;
   }
   const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(),
                                       engine_ctx_pair.second.get());
   if (retry) {
-    LOG(WARNING) << "Failed to execute engine, retrying with native segment";
+    LOG(WARNING) << "Failed to execute engine, "
+                 << "retrying with native segment for " << name();
     ExecuteNativeSegment(ctx, helper);
     return;
   }
@@ -412,6 +418,7 @@
     LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
     return kRetry;
   }
+  test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done");
   // Synchronization will be done by TF.
   return !kRetry;
 }
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
index fe4fa16..7cdfe2b 100644
--- a/tensorflow/contrib/tensorrt/python/__init__.py
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -20,7 +20,11 @@
 
 # pylint: disable=unused-import,line-too-long
 from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value
 from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values
 from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value
+from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value
 from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled
 # pylint: enable=unused-import,line-too-long
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 2b67931..5c1f4a4 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -20,9 +20,13 @@
 
 # pylint: disable=unused-import,line-too-long
 import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
 from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
+from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
+from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value
 from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version
 from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
+from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
 from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
 from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
 from tensorflow.core.framework import graph_pb2
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index edd30ad..9d14e63 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -20,17 +20,19 @@
 
 import numpy as np
 
+from tensorflow.contrib.tensorrt.python import trt_convert
 from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
 
 
-class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
+class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
 
   def GetParams(self):
     """Create a graph containing single segment."""
@@ -65,13 +67,17 @@
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=1,
+        # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+        # breaks the connection check, fix it.
+        # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
+        #   "relu", "identity", "max_pool"]
+        expected_engines=["my_trt_op_0"],
         expected_output_dims=(100, 6, 6, 6),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
 
 
-class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
+class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
 
   def GetParams(self):
     """Create a graph containing multiple segment."""
@@ -95,32 +101,138 @@
             padding="SAME",
             name="conv")
         c1 = constant_op.constant(
-            np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
-        p = conv * c1
+            np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1")
+        p = math_ops.mul(conv, c1, name="mul")
         c2 = constant_op.constant(
-            np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
-        q = conv / c2
+            np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2")
+        q = math_ops.div(conv, c2, name="div")
 
-        edge = self.trt_incompatible_op(q)
-        edge /= edge
-        r = edge + edge
+        edge = self.trt_incompatible_op(q, name="incompatible")
+        edge = math_ops.div(edge, edge, name="div1")
+        r = math_ops.add(edge, edge, name="add")
 
-        p -= edge
-        q *= edge
-        s = p + q
-        s -= r
+        p = math_ops.sub(p, edge, name="sub")
+        q = math_ops.mul(q, edge, name="mul1")
+        s = math_ops.add(p, q, name="add1")
+        s = math_ops.sub(s, r, name="sub1")
       array_ops.squeeze(s, name=self.output_name)
     return trt_test.TfTrtIntegrationTestParams(
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=2,
+        # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+        # breaks the connection check, fix it.
+        # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", "add",
+        #   "sub1"];
+        # - my_trt_op_1 should have ["weights","conv", "div"]
+        expected_engines=["my_trt_op_0", "my_trt_op_1"],
         expected_output_dims=(100, 12, 12, 6),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
 
 
-# TODO(aaroey): add a large complex graph to test.
+class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
+
+  def setUp(self):
+    """Setup method."""
+    super(PartiallyConvertedTestA, self).setUp()
+    # Let it fail to build the first engine.
+    trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail")
+
+  def GetParams(self):
+    """Create a graph containing two segment."""
+    input_name = "input"
+    input_dims = [2, 32, 32, 3]
+    g = ops.Graph()
+    with g.as_default():
+      inp = array_ops.placeholder(
+          dtype=dtypes.float32, shape=input_dims, name=input_name)
+      with g.device("/GPU:0"):
+        n = inp
+        for i in range(2):
+          c = constant_op.constant(1.0, name="c%d" % i)
+          n = math_ops.add(n, c, name="add%d" % i)
+          n = math_ops.mul(n, n, name="mul%d" % i)
+        edge = self.trt_incompatible_op(n, name="incompatible")
+        with g.control_dependencies([edge]):
+          c = constant_op.constant(1.0, name="c2")
+          n = math_ops.add(n, c, name="add2")
+        n = math_ops.mul(n, n, name="mul2")
+        c = constant_op.constant(1.0, name="c3")
+        n = math_ops.add(n, c, name="add3")
+        n = math_ops.mul(n, n, name="mul3")
+      array_ops.squeeze(n, name=self.output_name)
+    return trt_test.TfTrtIntegrationTestParams(
+        gdef=g.as_graph_def(),
+        input_names=[input_name],
+        input_dims=[input_dims],
+        expected_engines={
+            # Only the second engine is built.
+            "my_trt_op_1": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
+        },
+        expected_output_dims=tuple(input_dims),
+        allclose_atol=1.e-06,
+        allclose_rtol=1.e-06)
+
+
+class PartiallyConvertedTestB(PartiallyConvertedTestA):
+
+  def setUp(self):
+    """Setup method."""
+    super(PartiallyConvertedTestB, self).setUp()
+    # Let it fail to build the second engine.
+    trt_convert.clear_test_values("")
+    trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail")
+
+  def GetParams(self):
+    """Create a graph containing two segment."""
+    return super(PartiallyConvertedTestB, self).GetParams()._replace(
+        expected_engines={
+            # Only the first engine is built.
+            "my_trt_op_0": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
+        })
+
+
+class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
+
+  def GetParams(self):
+    """Create a graph containing multiple segment."""
+    input_name = "input"
+    input_dims = [2, 32, 32, 3]
+    g = ops.Graph()
+    with g.as_default():
+      inp = array_ops.placeholder(
+          dtype=dtypes.float32, shape=input_dims, name=input_name)
+      with g.device("/GPU:0"):
+        n = inp
+        c = constant_op.constant(1.0, name="c")
+        # Adds control dependency from the constant op to a trt incompatible op,
+        # and adds control dependency from the trt incompatible op to all other
+        # ops, to make sure the constant op cannot be contracted with any trt
+        # segment that depends on it.
+        with g.control_dependencies([c]):
+          d = self.trt_incompatible_op(n, name="incompatible")
+        with g.control_dependencies([d]):
+          n = math_ops.add(n, c, name="add")
+          n = math_ops.mul(n, n, name="mul")
+          n = math_ops.add(n, n, name="add1")
+        n = self.trt_incompatible_op(n, name="incompatible1")
+        with g.control_dependencies([d]):
+          n = math_ops.add(n, c, name="add2")
+          n = math_ops.mul(n, n, name="mul1")
+          n = math_ops.add(n, n, name="add3")
+      array_ops.squeeze(n, name=self.output_name)
+    return trt_test.TfTrtIntegrationTestParams(
+        gdef=g.as_graph_def(),
+        input_names=[input_name],
+        input_dims=[input_dims],
+        expected_engines={
+            "my_trt_op_0": ["add2", "add3", "mul1"],
+            "my_trt_op_1": ["add", "add1", "mul"]
+        },
+        expected_output_dims=tuple(input_dims),
+        allclose_atol=1.e-06,
+        allclose_rtol=1.e-06)
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
index 730b684..2e1107e 100644
--- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
@@ -66,7 +66,7 @@
         gdef=g.as_graph_def(),
         input_names=[input_name, w1_name, w2_name],
         input_dims=[input_dims, w1_dims, w2_dims],
-        num_expected_engines=1,
+        expected_engines=["my_trt_op_0"],
         expected_output_dims=(12, 5, 8, 7),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 0c03a10..8be32f5 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -102,7 +102,10 @@
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=7,
+        expected_engines=[
+            "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+            "my_trt_op_4", "my_trt_op_5", "my_trt_op_6"
+        ],
         expected_output_dims=(48, 89),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
index dd67346..9316b14 100644
--- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
@@ -109,7 +109,24 @@
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=16,
+        expected_engines=[
+            "my_trt_op_0",
+            "my_trt_op_1",
+            "my_trt_op_2",
+            "my_trt_op_3",
+            "my_trt_op_4",
+            "my_trt_op_5",
+            "my_trt_op_6",
+            "my_trt_op_7",
+            "my_trt_op_8",
+            "my_trt_op_9",
+            "my_trt_op_10",
+            "my_trt_op_11",
+            "my_trt_op_12",
+            "my_trt_op_13",
+            "my_trt_op_14",
+            "my_trt_op_15",
+        ],
         expected_output_dims=(5, 23040),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py
index 8c51c45..1874b9d 100644
--- a/tensorflow/contrib/tensorrt/test/concatenation_test.py
+++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py
@@ -73,7 +73,7 @@
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=1,
+        expected_engines=["my_trt_op_0"],
         expected_output_dims=(2, 126),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
index 97b29bf05..8c59000 100644
--- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
@@ -58,7 +58,7 @@
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=1,
+        expected_engines=['my_trt_op_0'],
         expected_output_dims=(5, 12, 12, 1),
         allclose_atol=1.e-02,
         allclose_rtol=1.e-02)
diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
index 734ccf6..fd55b8c 100644
--- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
@@ -77,7 +77,7 @@
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=2,
+        expected_engines=["my_trt_op_0", "my_trt_op_1"],
         expected_output_dims=(2, 4, 5, 4),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
index 50265c0..97e0d23 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
@@ -59,7 +59,7 @@
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=2,
+        expected_engines=["my_trt_op_0", "my_trt_op_1"],
         expected_output_dims=(2, 4, 5, 4),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index bb7f5a7..5968af2 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -30,6 +30,7 @@
 # pylint: enable=unused-import
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import graph_io
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -37,10 +38,14 @@
 from tensorflow.python.platform import tf_logging as logging
 
 TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [
-    "gdef", "input_names", "input_dims", "num_expected_engines",
+    "gdef", "input_names", "input_dims", "expected_engines",
     "expected_output_dims", "allclose_atol", "allclose_rtol"
 ])
 
+RunParams = namedtuple(
+    "RunParams",
+    ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"])
+
 PRECISION_MODES = ["FP32", "FP16", "INT8"]
 
 
@@ -48,6 +53,12 @@
   return mode == "INT8"
 
 
+class GraphState:
+  ORIGINAL = 0
+  CALIBRATE = 1
+  INFERENCE = 2
+
+
 class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
   """Class to test Tensorflow-TensorRT integration."""
 
@@ -63,34 +74,79 @@
   def precision_modes(self):
     return ["FP32", "FP16", "INT8"]
 
+  # str is bytes in py2, but unicode in py3.
+  def _ToUnicode(self, s):
+    if six.PY2:
+      if isinstance(s, unicode):
+        return s
+      return s.decode("utf-8")
+    else:
+      if isinstance(s, str):
+        return s
+      return s.decode("utf-8")
+
   def _ToBytes(self, s):
     if six.PY2:
+      if isinstance(s, unicode):
+        return s.encode("utf-8")
       return s
     else:
-      return s.encode("utf-8")
+      if isinstance(s, str):
+        return s.encode("utf-8")
+      return s
 
   def _ToString(self, s):
     if six.PY2:
+      if isinstance(s, unicode):
+        return s.encode("utf-8")
       return s
     else:
+      if isinstance(s, str):
+        return s
       return s.decode("utf-8")
 
+  @classmethod
+  def setUpClass(cls):
+    """Setup method for the module."""
+    super(TfTrtIntegrationTestBase, cls).setUpClass()
+    trt_convert.enable_test_value()
+
   def setUp(self):
     """Setup method."""
     super(TfTrtIntegrationTestBase, self).setUp()
     warnings.simplefilter("always")
+    trt_convert.clear_test_values("")
 
   def GetParams(self):
     """Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
     raise NotImplementedError()
 
-  def _GetConfigProto(self,
-                      params,
-                      use_optimizer,
-                      precision_mode=None,
-                      is_dynamic_op=None):
+  def _PrepareRun(self, params, graph_state):
+    """Set up necessary testing environment before calling sess.run()."""
+    # Clear test values added by TRTEngineOp.
+    trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine")
+    trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration")
+    trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment")
+
+  def _VerifyRun(self, params, graph_state):
+    """Verify the state after sess.run()."""
+    for engine_name in params.expected_engines:
+      if graph_state == GraphState.ORIGINAL:
+        self._ExpectCalibration(engine_name, "")
+        self._ExpectNativeSegment(engine_name, "")
+        self._ExpectTrtEngine(engine_name, "")
+      elif graph_state == GraphState.CALIBRATE:
+        self._ExpectCalibration(engine_name, "done")
+        self._ExpectNativeSegment(engine_name, "done")
+        self._ExpectTrtEngine(engine_name, "")
+      elif graph_state == GraphState.INFERENCE:
+        self._ExpectCalibration(engine_name, "")
+        self._ExpectNativeSegment(engine_name, "")
+        self._ExpectTrtEngine(engine_name, "done")
+
+  def _GetConfigProto(self, params, run_params, graph_state):
     """Get config proto based on specific settings."""
-    if use_optimizer:
+    if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
       rewriter_cfg = rewriter_config_pb2.RewriterConfig()
       rewriter_cfg.optimizers.extend(["constfold", "layout"])
       custom_op = rewriter_cfg.custom_optimizers.add()
@@ -98,14 +154,31 @@
       custom_op.parameter_map["minimum_segment_size"].i = 3
       custom_op.parameter_map["max_batch_size"].i = max(
           [dims[0] for dims in params.input_dims])
-      custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op
+      custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine
       custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
       custom_op.parameter_map["precision_mode"].s = self._ToBytes(
-          precision_mode)
+          run_params.precision_mode)
       graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
     else:
       graph_options = config_pb2.GraphOptions()
 
+    # Disable all other optimizations which can affect the converted graph.
+    off = rewriter_config_pb2.RewriterConfig.OFF
+    graph_options.optimizer_options.opt_level = config_pb2.OptimizerOptions.L0
+    graph_options.rewrite_options.layout_optimizer = off
+    graph_options.rewrite_options.constant_folding = off
+    graph_options.rewrite_options.shape_optimization = off
+    graph_options.rewrite_options.remapping = off
+    graph_options.rewrite_options.arithmetic_optimization = off
+    graph_options.rewrite_options.dependency_optimization = off
+    graph_options.rewrite_options.loop_optimization = off
+    graph_options.rewrite_options.function_optimization = off
+    graph_options.rewrite_options.debug_stripper = off
+    graph_options.rewrite_options.disable_model_pruning = True
+    graph_options.rewrite_options.scoped_allocator_optimization = off
+    graph_options.rewrite_options.memory_optimization = (
+        rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
+
     gpu_options = config_pb2.GPUOptions()
     gpu_options.allow_growth = True
     if trt_convert.get_linked_tensorrt_version()[0] == 3:
@@ -115,7 +188,21 @@
         gpu_options=gpu_options, graph_options=graph_options)
     return config
 
-  def _RunGraph(self, params, gdef, input_data, config, num_runs=2):
+  def _ExpectTestValue(self, engine_name, method, value):
+    self.assertEqual(
+        value, trt_convert.get_test_value("%s:%s" % (engine_name, method)))
+
+  def _ExpectCalibration(self, engine_name, value):
+    self._ExpectTestValue(engine_name, "ExecuteCalibration", value)
+
+  def _ExpectTrtEngine(self, engine_name, value):
+    self._ExpectTestValue(engine_name, "ExecuteTrtEngine", value)
+
+  def _ExpectNativeSegment(self, engine_name, value):
+    self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value)
+
+  def _RunGraph(self, params, gdef, input_data, config, graph_state,
+                num_runs=2):
     """Run given graphdef multiple times."""
     assert len(params.input_names) == len(input_data)
     g = ops.Graph()
@@ -132,93 +219,166 @@
       val = None
       # Defaults to 2 runs to verify result across multiple runs is same.
       for _ in range(num_runs):
+        self._PrepareRun(params, graph_state)
         new_val = sess.run(out,
                            {inp[i]: input_data[i] for i in range(len(inp))})
         self.assertEqual(params.expected_output_dims, new_val.shape)
         if val is not None:
           self.assertAllEqual(val, new_val)
         val = new_val
+        self._VerifyRun(params, graph_state)
     return val
 
   # Use real data that is representative of the inference dataset
   # for calibration. For this test script it is random data.
   def _RunCalibration(self, params, gdef, input_data, config):
     """Run calibration on given graph."""
-    return self._RunGraph(params, gdef, input_data, config, 30)
+    return self._RunGraph(
+        params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
 
-  def _GetTrtGraphDef(self, params, gdef, precision_mode, is_dynamic_op):
+  def _GetTrtGraphDef(self, params, run_params, gdef):
     """Return trt converted graphdef."""
     return trt_convert.create_inference_graph(
         input_graph_def=gdef,
         outputs=[self.output_name],
         max_batch_size=max([dims[0] for dims in params.input_dims]),
         max_workspace_size_bytes=1 << 25,
-        precision_mode=precision_mode,
+        precision_mode=run_params.precision_mode,
         minimum_segment_size=2,
-        is_dynamic_op=is_dynamic_op)
+        is_dynamic_op=run_params.dynamic_engine)
 
-  def _VerifyGraphDef(self,
-                      params,
-                      gdef,
-                      precision_mode=None,
-                      is_calibrated=None,
-                      dynamic_engine=None):
+  def _WriteGraph(self, params, run_params, gdef, graph_state):
+    if graph_state == GraphState.ORIGINAL:
+      label = "Original"
+    elif graph_state == GraphState.CALIBRATE:
+      label = "CalibEngine"
+    elif graph_state == GraphState.INFERENCE:
+      label = "InferEngine"
+    graph_name = (
+        self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
+        ".pbtxt")
+    logging.info("Writing graph to %s/%s", self.get_temp_dir(), graph_name)
+    graph_io.write_graph(gdef, self.get_temp_dir(), graph_name)
+
+  def _VerifyConnections(self, params, converted_gdef):
+    old_to_new_node_map = {
+        self._ToString(n.name): self._ToString(n.name) for n in params.gdef.node
+    }
+    for engine_name, node_names in params.expected_engines.iteritems():
+      for n in node_names:
+        old_to_new_node_map[n] = engine_name
+    name_to_node_map = {self._ToString(n.name): n for n in params.gdef.node}
+
+    def input_name(inp):
+      inp = self._ToString(inp)
+      prefix = ""
+      if inp[0] == "^":
+        prefix = "^"
+        inp = inp[1:]
+      parts = inp.split(":")
+      if len(parts) > 1 and parts[-1].isdigit():
+        inp = inp[:-len(parts[-1]) - 1]
+      return (prefix, inp)
+
+    expected_input_map = {}
+    for n in params.gdef.node:
+      name_str = self._ToString(n.name)
+      target_node_name = old_to_new_node_map[name_str]
+      is_engine_op = (target_node_name != name_str)
+      if target_node_name not in expected_input_map:
+        expected_input_map[target_node_name] = set()
+      input_set = expected_input_map[target_node_name]
+      for inp in n.input:
+        (prefix, inp_name) = input_name(inp)
+        # Add the input only if it's outside the segment (note that it could be
+        # in a different engine).
+        if (not is_engine_op or
+            old_to_new_node_map[inp_name] != target_node_name):
+          if is_engine_op and name_to_node_map[inp_name].op == "Const":
+            # Const data input nodes to the segment have been copied to the
+            # segment graphdef and the engine, and the dependency has been
+            # converted to a control dependency.
+            input_set.add("^" + old_to_new_node_map[inp_name])
+          else:
+            input_set.add(prefix + old_to_new_node_map[inp_name])
+
+    actual_input_map = {}
+    for n in converted_gdef.node:
+      name_str = self._ToString(n.name)
+      actual_input_map[name_str] = set()
+      input_set = actual_input_map[name_str]
+      for inp in n.input:
+        (prefix, node_name) = input_name(inp)
+        input_set.add(prefix + node_name)
+
+    self.assertEqual(
+        expected_input_map,
+        actual_input_map,
+        msg="expected:\n%s\nvs actual:\n%s" % (expected_input_map,
+                                               actual_input_map))
+
+  def _VerifyGraphDef(self, params, run_params, gdef, graph_state):
+    self._WriteGraph(params, run_params, gdef, graph_state)
+
     num_engines = 0
     for n in gdef.node:
-      # TODO(jie): we should have coverage for failed conversion (TF fallback).
-      # where the conversion will fail and we shouldn't count this engine as the
-      # converted engines.
       if n.op == "TRTEngineOp":
         num_engines += 1
-        self.assertNotEqual(self._ToBytes(""), n.attr["serialized_segment"].s)
-        self.assertNotEqual(self._ToBytes(""), n.attr["segment_funcdef_name"].s)
+        self.assertTrue(n.name in params.expected_engines)
+        self.assertTrue(len(n.attr["serialized_segment"].s))
+        self.assertTrue(len(n.attr["segment_funcdef_name"].s))
         self.assertEqual(
-            self._ToBytes(precision_mode), n.attr["precision_mode"].s)
-        self.assertEqual(not dynamic_engine, n.attr["static_engine"].b)
-        if _IsQuantizationMode(precision_mode) and is_calibrated:
-          self.assertNotEqual(self._ToBytes(""), n.attr["calibration_data"].s)
+            self._ToBytes(run_params.precision_mode),
+            n.attr["precision_mode"].s)
+
+        is_dynamic_engine = not n.attr["static_engine"].b
+        self.assertEqual(run_params.dynamic_engine, is_dynamic_engine)
+
+        has_calibration_data = len(n.attr["calibration_data"].s)
+        if (_IsQuantizationMode(run_params.precision_mode) and
+            graph_state == GraphState.INFERENCE):
+          self.assertTrue(has_calibration_data)
         else:
-          self.assertEqual(self._ToBytes(""), n.attr["calibration_data"].s)
-    if precision_mode is None:  # This means gdef is the original GraphDef.
+          self.assertFalse(has_calibration_data)
+    if graph_state == GraphState.ORIGINAL:
       self.assertEqual(0, num_engines)
     else:
-      self.assertEqual(num_engines, params.num_expected_engines)
+      self.assertEqual(num_engines, len(params.expected_engines))
+      if isinstance(params.expected_engines, dict):
+        self._VerifyConnections(params, gdef)
+      # TODO(aaroey): consider verifying the corresponding TF function.
 
-  def RunTest(self, params, use_optimizer, precision_mode,
-              dynamic_infer_engine, dynamic_calib_engine):
-    assert precision_mode in PRECISION_MODES
+  def RunTest(self, params, run_params):
+    assert run_params.precision_mode in PRECISION_MODES
     input_data = [np.random.random_sample(dims) for dims in params.input_dims]
     input_gdef = params.gdef
-    self._VerifyGraphDef(params, input_gdef)
+    self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL)
 
     # Get reference result without running trt.
-    config_no_trt = self._GetConfigProto(params, False)
+    config_no_trt = self._GetConfigProto(params, run_params,
+                                         GraphState.ORIGINAL)
     logging.info("Running original graph w/o trt, config:\n%s",
                  str(config_no_trt))
-    ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt)
+    ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt,
+                                GraphState.ORIGINAL)
 
     # Run calibration if necessary.
-    if _IsQuantizationMode(precision_mode):
+    if _IsQuantizationMode(run_params.precision_mode):
 
-      calib_config = self._GetConfigProto(params, use_optimizer, precision_mode,
-                                          dynamic_calib_engine)
+      calib_config = self._GetConfigProto(params, run_params,
+                                          GraphState.CALIBRATE)
       logging.info("Running calibration graph, config:\n%s", str(calib_config))
-      if use_optimizer:
-        self.assertTrue(False)
-        # TODO(aaroey): uncomment this and get infer_gdef when this mode is
-        # supported.
-        # result = self._RunCalibration(params, input_gdef, input_data,
-        #                               calib_config)
+      if run_params.use_optimizer:
+        result = self._RunCalibration(params, input_gdef, input_data,
+                                      calib_config)
       else:
-        calib_gdef = self._GetTrtGraphDef(params, input_gdef, precision_mode,
-                                          dynamic_calib_engine)
-        self._VerifyGraphDef(params, calib_gdef, precision_mode, False,
-                             dynamic_calib_engine)
+        calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef)
+        self._VerifyGraphDef(params, run_params, calib_gdef,
+                             GraphState.CALIBRATE)
         result = self._RunCalibration(params, calib_gdef, input_data,
                                       calib_config)
-        infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
-        self._VerifyGraphDef(params, infer_gdef, precision_mode, True,
-                             dynamic_calib_engine)
+      infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
+      self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE)
 
       self.assertAllClose(
           ref_result,
@@ -229,18 +389,19 @@
       infer_gdef = input_gdef
 
     # Run inference.
-    infer_config = self._GetConfigProto(params, use_optimizer, precision_mode,
-                                        dynamic_infer_engine)
+    infer_config = self._GetConfigProto(params, run_params,
+                                        GraphState.INFERENCE)
     logging.info("Running final inference graph, config:\n%s",
                  str(infer_config))
-    if use_optimizer:
-      result = self._RunGraph(params, infer_gdef, input_data, infer_config)
+    if run_params.use_optimizer:
+      result = self._RunGraph(params, infer_gdef, input_data, infer_config,
+                              GraphState.INFERENCE)
     else:
-      trt_infer_gdef = self._GetTrtGraphDef(params, infer_gdef, precision_mode,
-                                            dynamic_infer_engine)
-      self._VerifyGraphDef(params, trt_infer_gdef, precision_mode, True,
-                           dynamic_infer_engine)
-      result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config)
+      trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef)
+      self._VerifyGraphDef(params, run_params, trt_infer_gdef,
+                           GraphState.INFERENCE)
+      result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config,
+                              GraphState.INFERENCE)
 
     self.assertAllClose(
         ref_result,
@@ -263,66 +424,44 @@
 def _AddTests(test_class):
   """Adds test methods to TfTrtIntegrationTestBase."""
 
-  def _GetTest(use_optimizer, precision_mode, dynamic_infer_engine,
-               dynamic_calib_engine):
+  def _GetTest(run_params):
     """Gets a single test method based on the parameters."""
 
     def _Test(self):
       params = self.GetParams()
       logging.info(
-          "Running test with parameters: use_optimizer=%s, precision_mode=%s, "
-          "dynamic_infer_engine=%s, dynamic_calib_engine=%s", use_optimizer,
-          precision_mode, dynamic_infer_engine, dynamic_calib_engine)
-      self.RunTest(params, use_optimizer, precision_mode, dynamic_infer_engine,
-                   dynamic_calib_engine)
+          "Running test %s with parameters: use_optimizer=%s, "
+          "precision_mode=%s, dynamic_engine=%s",
+          "testTfTRT_" + run_params.test_name, run_params.use_optimizer,
+          run_params.precision_mode, run_params.dynamic_engine)
+      self.RunTest(params, run_params)
 
     return _Test
 
   use_optimizer_options = [False, True]
-  dynamic_infer_engine_options = [False, True]
-  dynamic_calib_engine_options = [False, True]
-  for (use_optimizer, precision_mode,
-       dynamic_infer_engine, dynamic_calib_engine) in itertools.product(
-           use_optimizer_options, PRECISION_MODES, dynamic_infer_engine_options,
-           dynamic_calib_engine_options):
+  dynamic_engine_options = [False, True]
+  for (use_optimizer, precision_mode, dynamic_engine) in itertools.product(
+      use_optimizer_options, PRECISION_MODES, dynamic_engine_options):
     if _IsQuantizationMode(precision_mode):
-      if not dynamic_calib_engine and dynamic_infer_engine:
-        # TODO(aaroey): test this case, the conversion from static calibration
-        # engine to dynamic inference engine should be a noop.
-        continue
       if use_optimizer:
         # TODO(aaroey): if use_optimizer is True we need to get the inference
         # graphdef using custom python wrapper class, which is not currently
         # supported yet.
         continue
-      if not dynamic_calib_engine:
+      if not dynamic_engine:
         # TODO(aaroey): construction of static calibration engine is not
         # supported yet.
         continue
-      if dynamic_calib_engine and not dynamic_infer_engine:
-        # TODO(aaroey): construction of static inference engine using dynamic
-        # calibration engine is not supported yet.
-        continue
-    else:  # In non int8 mode.
-      if dynamic_calib_engine:
-        # dynamic_calib_engine doesn't affect non-int8 modes, so just let
-        # related tests run once on dynamic_calib_engine=False.
-        continue
 
     conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
-    infer_engine_type = ("DynamicInferEngine"
-                         if dynamic_infer_engine else "StaticInferEngine")
-    calib_engine_type = ""
-    if precision_mode == "INT8":
-      calib_engine_type = ("DynamicCalibEngine"
-                           if dynamic_calib_engine else "StaticCalibEngine")
-    test_name = "%s_%s_%s%s" % (conversion, precision_mode, infer_engine_type,
-                                ("_" + calib_engine_type)
-                                if len(calib_engine_type) else "")
-    setattr(
-        test_class, "testTfTRT_" + test_name,
-        _GetTest(use_optimizer, precision_mode, dynamic_infer_engine,
-                 dynamic_calib_engine))
+    engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine")
+    test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type)
+    run_params = RunParams(
+        use_optimizer=use_optimizer,
+        precision_mode=precision_mode,
+        dynamic_engine=dynamic_engine,
+        test_name=test_name)
+    setattr(test_class, "testTfTRT_" + test_name, _GetTest(run_params))
 
 
 if trt_convert.is_tensorrt_enabled():
diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py
index b9e977c..500057a 100644
--- a/tensorflow/contrib/tensorrt/test/unary_test.py
+++ b/tensorflow/contrib/tensorrt/test/unary_test.py
@@ -100,7 +100,10 @@
         gdef=g.as_graph_def(),
         input_names=[input_name, input2_name],
         input_dims=[input_dims, input2_dims],
-        num_expected_engines=5,
+        expected_engines=[
+            "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+            "my_trt_op_4"
+        ],
         expected_output_dims=(12, 5, 8, 12),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc
new file mode 100644
index 0000000..319ddea
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/utils.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/test/utils.h"
+
+#include <unordered_map>
+#include <vector>
+
+#include "re2/re2.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace test {
+
+// TODO(aaroey): make this class thread-safe.
+class TestValueManager {
+ public:
+  static TestValueManager* singleton() {
+    static TestValueManager* manager = new TestValueManager();
+    return manager;
+  }
+
+  void Enable() {
+    VLOG(1) << "Enabling test value";
+    enabled_ = true;
+  }
+
+  void Add(const string& label, const string& value) {
+    if (TF_PREDICT_FALSE(enabled_)) {
+      QCHECK_NE("", value);
+      VLOG(1) << "Adding test value: " << label << " -> " << value;
+      values_.insert({label, value});
+    }
+  }
+
+  string Get(const string& label) {
+    if (TF_PREDICT_FALSE(enabled_)) {
+      VLOG(1) << "Getting test value by " << label;
+      auto itr = values_.find(label);
+      if (itr == values_.end()) return "";
+      return itr->second;
+    }
+    return "";
+  }
+
+  void Clear(const string& pattern) {
+    if (TF_PREDICT_FALSE(enabled_)) {
+      VLOG(1) << "Clearing test values";
+      if (pattern == "") {
+        values_.clear();
+        return;
+      }
+      std::vector<string> keys_to_clear;
+      for (const auto& kv : values_) {
+        if (RE2::FullMatch(kv.first, pattern)) {
+          keys_to_clear.push_back(kv.first);
+        }
+      }
+      for (const string& key : keys_to_clear) {
+        values_.erase(key);
+      }
+    }
+  }
+
+ private:
+  TestValueManager() : enabled_(false) {}
+
+  bool enabled_;
+  std::unordered_map<string, string> values_;
+};
+
+void EnableTestValue() { TestValueManager::singleton()->Enable(); }
+
+void ClearTestValues(const string& pattern) {
+  TestValueManager::singleton()->Clear(pattern);
+}
+
+void AddTestValue(const string& label, const string& value) {
+  TestValueManager::singleton()->Add(label, value);
+}
+
+string GetTestValue(const string& label) {
+  return TestValueManager::singleton()->Get(label);
+}
+
+}  // namespace test
+}  // namespace tensorrt
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/contrib/tensorrt/test/utils.h
new file mode 100644
index 0000000..625cd3d
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/utils.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace test {
+
+// Helper methods to inject values used by testing tools.
+void EnableTestValue();
+void ClearTestValues(const string& pattern);
+void AddTestValue(const string& label, const string& value);
+string GetTestValue(const string& label);
+
+#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return)                       \
+  do {                                                                         \
+    if (::tensorflow::tensorrt::test::GetTestValue(label) == value_to_return) {\
+      return errors::Internal("Injected manually");                            \
+    }                                                                          \
+  } while(0)
+
+}  // namespace test
+}  // namespace tensorrt
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
index 2b134c3..ab4d224 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
@@ -72,7 +72,7 @@
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=1,
+        expected_engines=["my_trt_op_0"],
         expected_output_dims=(5, 6, 2, 2),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
index bec2f23..56bdf84 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
@@ -63,7 +63,7 @@
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        num_expected_engines=1,
+        expected_engines=["my_trt_op_0"],
         expected_output_dims=(5, 2, 2, 6),
         allclose_atol=1.e-03,
         allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index 422740f..921c263 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -101,6 +101,7 @@
 #include "tensorflow/core/util/stat_summarizer.h"
 #include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
 #include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
 %}
 
 %ignoreall
@@ -110,6 +111,10 @@
 %unignore get_linked_tensorrt_version;
 %unignore get_loaded_tensorrt_version;
 %unignore is_tensorrt_enabled;
+%unignore enable_test_value;
+%unignore clear_test_values;
+%unignore add_test_value;
+%unignore get_test_value;
 
 %{
 
@@ -251,6 +256,22 @@
   return tensorflow::tensorrt::IsGoogleTensorRTEnabled();
 }
 
+void enable_test_value() {
+  tensorflow::tensorrt::test::EnableTestValue();
+}
+
+void clear_test_values(string pattern) {
+  tensorflow::tensorrt::test::ClearTestValues(pattern);
+}
+
+void add_test_value(string label, string value) {
+  tensorflow::tensorrt::test::AddTestValue(label, value);
+}
+
+string get_test_value(string label) {
+  return tensorflow::tensorrt::test::GetTestValue(label);
+}
+
 %}
 
 std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op);
@@ -266,5 +287,9 @@
 version_struct get_linked_tensorrt_version();
 version_struct get_loaded_tensorrt_version();
 bool is_tensorrt_enabled();
+void enable_test_value();
+void clear_test_values(string pattern);
+void add_test_value(string label, string value);
+string get_test_value(string label);
 
 %unignoreall