Introducing a TensorRT operator to TensorFlow which can run (sub)graphs in
highly optimized TensorRT engines.  This commit is a merged version of
many commits by

   benbarsdell    <bbarsdell at nvidia.com>
   deadeyegoodwin <davidg at nvidia.com>
   jjsjann123     <jiej at nvidia.com>
   samikama       <skama at nvidia.com>
diff --git a/configure.py b/configure.py
index cf16ef4..580bbc0 100644
--- a/configure.py
+++ b/configure.py
@@ -37,12 +37,14 @@
 _TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              'WORKSPACE')
 _DEFAULT_CUDA_VERSION = '9.0'
+_DEFAULT_TENSORRT_VERSION = '4'
 _DEFAULT_CUDNN_VERSION = '7'
 _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
 _DEFAULT_CUDA_PATH = '/usr/local/cuda'
 _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
 _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
                           'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
+_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu'
 _TF_OPENCL_VERSION = '1.2'
 _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
 _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@@ -382,13 +384,12 @@
 
   var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
   environ_cp[var_name] = var
-  if var == '1':
-    write_to_bazelrc('build --define %s=true' % option_name)
-  elif bazel_config_name is not None:
-    # TODO(mikecase): Migrate all users of configure.py to use --config Bazel
-    # options and not to set build configs through environment variables.
-    write_to_bazelrc('build:%s --define %s=true'
-                     % (bazel_config_name, option_name))
+  # TODO(mikecase): Migrate all users of configure.py to use --config Bazel
+  # options and not to set build configs through environment variables.
+  if var == '1':
+    setting = 'true'
+    confname = ':%s' % bazel_config_name if bazel_config_name is not None else ''
+    write_to_bazelrc('build%s --define %s=%s' % (confname, option_name, setting))
 
 
 def set_action_env_var(environ_cp,
@@ -438,13 +439,12 @@
   for seg in version_segments:
     if not seg.isdigit():
       return None
-
   version_str = ''.join(['%03d' % int(seg) for seg in version_segments])
   return int(version_str)
 
 
 def check_bazel_version(min_version):
-  """Check installed bezel version is at least min_version.
+  """Check installed bazel version is at least min_version.
 
   Args:
     min_version: string for minimum bazel version.
@@ -1056,6 +1056,108 @@
       write_to_bazelrc('test --config=cuda')
 
 
+def set_tf_trt_version(environ_cp):
+  """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION."""
+  ask_trt_version = (
+      'Please specify the TensorRT (libnvinfer) version you want to use. '
+      '[Leave empty to default to libnvinfer %s]: ') % _DEFAULT_TENSORRT_VERSION
+
+  while True:
+    tf_trt_version = get_from_env_or_user_or_default(
+        environ_cp, 'TF_TENSORRT_VERSION', ask_trt_version,
+        _DEFAULT_TENSORRT_VERSION)
+    # Use the previously supplied install path (if any) as the default.
+    default_trt_path = environ_cp.get('TENSORRT_INSTALL_PATH',
+                                      _DEFAULT_TENSORRT_PATH_LINUX)
+    ask_trt_path = (r'Please specify the location where libnvinfer %s library is '
+                      'installed. Refer to README.md for more details. [Default'
+                      ' is %s]:') % (tf_trt_version, default_trt_path)
+    trt_install_path = get_from_env_or_user_or_default(
+        environ_cp, 'TENSORRT_INSTALL_PATH', ask_trt_path, default_trt_path)
+
+    # Result returned from "read" will be used unexpanded. That makes "~"
+    # unusable. Going through one more level of expansion to handle that.
+    trt_install_path = os.path.realpath(
+        os.path.expanduser(trt_install_path))
+    # Helper that finds all libnvinfer.so* files directly under search_path
+    # and returns their absolute paths (the caller also checks the lib64
+    # subdirectory).
+    def find_libs(search_path):
+      fl = set()
+      if os.path.exists(search_path) and os.path.isdir(search_path):
+        fl.update([os.path.realpath(os.path.join(search_path, x))
+                   for x in os.listdir(search_path) if 'libnvinfer.so' in x])
+      return fl
+    possible_files = find_libs(trt_install_path)
+    possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
+    if is_linux():
+      cudnnpatt = re.compile(r'.*libcudnn.so\.?(.*) =>.*$')
+      cudapatt = re.compile(r'.*libcudart.so\.?(.*) =>.*$')
+      def is_compatible(lib, cudaver, cudnnver):
+        ldd_bin = which('ldd') or '/usr/bin/ldd'
+        ldd_out = run_shell([ldd_bin, lib]).split(os.linesep)
+        cudnn = cudart = None
+        for l in ldd_out:
+          if 'libcudnn.so' in l:
+            cudnn = cudnnpatt.search(l)
+          elif 'libcudart.so' in l:
+            cudart = cudapatt.search(l)
+        if cudnn:
+          cudnn = convert_version_to_int(cudnn.group(1)) if len(cudnn.group(1)) else 0
+        if cudart:
+          cudart = convert_version_to_int(cudart.group(1)) if len(cudart.group(1)) else 0
+        return (cudnn == cudnnver) and (cudart == cudaver)
+      cudaver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
+      cudnnver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
+      valid_libs = []
+      vfinder = re.compile(r'.*libnvinfer.so.?(.*)$')
+      highest_ver = [0, None, None]
+
+      for l in possible_files:
+        if is_compatible(l, cudaver, cudnnver):
+          valid_libs.append(l)
+          vstr = vfinder.search(l).group(1)
+          currver = convert_version_to_int(vstr) if len(vstr) else 0
+          if currver > highest_ver[0]:
+            highest_ver = [currver, vstr, l]
+      if highest_ver[1] is not None:
+        trt_install_path = os.path.dirname(highest_ver[2])
+        tf_trt_version = highest_ver[1]
+        break
+      ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
+      libnvinfer_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
+      libnvinfer_path_from_ldconfig = re.search('.*libnvinfer.so.* => (.*)',
+                                                libnvinfer_path_from_ldconfig)
+      if libnvinfer_path_from_ldconfig:
+        libnvinfer_path_from_ldconfig = libnvinfer_path_from_ldconfig.group(1)
+        if os.path.exists('%s.%s' % (libnvinfer_path_from_ldconfig,
+                                     tf_trt_version)):
+          trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
+          break
+
+    # Reset and retry
+    if len(possible_files):
+      print('Invalid path to TensorRT %s. The libnvinfer.so* files found are '
+            'built for incompatible CUDA/cuDNN versions. Searched:'
+            % tf_trt_version)
+      print(trt_install_path)
+      print(os.path.join(trt_install_path, 'lib64'))
+    else:
+      print('Invalid path to TensorRT %s. No libnvinfer.so* files were found '
+            'in:' % tf_trt_version)
+      print(trt_install_path)
+      print(os.path.join(trt_install_path, 'lib64'))
+      if is_linux():
+        print('%s.%s' % (libnvinfer_path_from_ldconfig, tf_trt_version))
+
+    environ_cp['TF_TENSORRT_VERSION'] = ''
+
+  # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
+  environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
+  write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
+  environ_cp['TF_TENSORRT_VERSION'] = tf_trt_version
+  write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_trt_version)
+  write_to_bazelrc('build:tensorrt --define using_tensorrt=true')
+
 def set_host_cxx_compiler(environ_cp):
   """Set HOST_CXX_COMPILER."""
   default_cxx_host_compiler = which('g++') or ''
@@ -1244,9 +1346,11 @@
     environ_cp['TF_NEED_COMPUTECPP'] = '0'
     environ_cp['TF_NEED_OPENCL'] = '0'
     environ_cp['TF_CUDA_CLANG'] = '0'
+    environ_cp['TF_NEED_TENSORRT'] = '0'
 
   if is_macos():
     environ_cp['TF_NEED_JEMALLOC'] = '0'
+    environ_cp['TF_NEED_TENSORRT'] = '0'
 
   set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
                 'with_jemalloc', True)
@@ -1301,6 +1405,10 @@
       if not is_windows():
         set_gcc_host_compiler_path(environ_cp)
     set_other_cuda_vars(environ_cp)
+    # Enable TensorRT if desired. Disabled on non-Linux platforms.
+    set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
+    if environ_cp.get('TF_NEED_TENSORRT') == '1':
+      set_tf_trt_version(environ_cp)
 
   set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
   if environ_cp.get('TF_NEED_MPI') == '1':
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index da37564..b374462 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -359,6 +359,14 @@
 )
 
 config_setting(
+    name = "using_tensorrt",
+    define_values = {
+        "using_tensorrt":"true",
+    },
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
     name = "with_mpi_support",
     values = {"define": "with_mpi_support=true"},
     visibility = ["//visibility:public"],
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 8bed0fa..e5c3017 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -7,6 +7,7 @@
 
 load("//third_party/mpi:mpi.bzl", "if_mpi")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_tensorrt//:build_defs.bzl", "if_trt")
 
 py_library(
     name = "contrib_py",
@@ -104,7 +105,9 @@
         "//tensorflow/contrib/training:training_py",
         "//tensorflow/contrib/util:util_py",
         "//tensorflow/python:util",
-    ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]),
+    ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"])
+    + if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
 )
 
 cc_library(
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
new file mode 100644
index 0000000..723c9f5
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -0,0 +1,266 @@
+# -*- python -*-
+# Description:
+#   Provides TensorRT operators and a converter package.
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"])  # Apache 2.0
+
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_custom_op_library",
+    "tf_gen_op_libs",
+    "tf_gen_op_wrapper_py",
+    "tf_py_wrap_cc",
+    "tf_cc_test",
+    "tf_kernel_library",
+    "tf_custom_op_py_library",
+    "tf_copts",
+)
+
+
+
+tf_custom_op_library(
+    name = "python/ops/_trt_engine_op.so",
+    srcs = [
+        "kernels/trt_engine_op.cc",
+        "ops/trt_engine_op.cc",
+        "kernels/trt_engine_op.h",
+    ],
+    gpu_srcs = [],
+    deps = [
+        "@local_config_tensorrt//:tensorrt",
+        ":trt_shape_function",
+        "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/core/kernels:bounds_check_lib",
+        "//tensorflow/core/kernels:ops_util_hdrs",
+    ],
+)
+
+cc_library(
+    name = "trt_shape_function",
+    srcs = [
+        "shape_fn/trt_shfn.cc",
+    ],
+    hdrs = ["shape_fn/trt_shfn.h"],
+    copts = tf_copts(),
+    deps = [
+        ":trt_logging",
+        "//third_party/eigen3",
+        "@local_config_tensorrt//:tensorrt",
+        "@protobuf_archive//:protobuf",
+        "@nsync//:nsync_headers",
+        "//tensorflow/core:framework_headers_lib",
+    ],
+)
+
+
+tf_kernel_library(
+    name = "trt_engine_op_kernel",
+    srcs = [
+        "kernels/trt_engine_op.cc",
+    ],
+    hdrs = [
+        "kernels/trt_engine_op.h",
+    ],
+    gpu_srcs = [],
+    deps = [
+        ":trt_logging",
+        ":trt_shape_function",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+        "//tensorflow/core:gpu_headers_lib",
+        "@local_config_tensorrt//:tensorrt",
+        "//tensorflow/core:lib_proto_parsing",
+    ],
+    alwayslink = 1,
+)
+
+tf_gen_op_libs(
+    op_lib_names = [
+        "trt_engine_op",
+    ],
+    deps = [
+        "@local_config_tensorrt//:tensorrt",
+    ],
+)
+
+
+cc_library(
+    name="trt_logging",
+    srcs = [
+         "log/trt_logger.cc",
+    ],
+    hdrs=[
+         "log/trt_logger.h",
+    ],
+    deps=[
+        "@local_config_tensorrt//:tensorrt",
+        "//tensorflow/core:lib_proto_parsing",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+tf_gen_op_wrapper_py(
+    name = "trt_engine_op",
+    deps = [
+        ":trt_engine_op_op_lib",
+        ":trt_shape_function",
+    ],
+)
+
+
+tf_custom_op_py_library(
+    name = "trt_engine_op_loader",
+    srcs = ["python/ops/trt_engine_op.py"],
+    dso = [
+        ":python/ops/_trt_engine_op.so",
+        "@local_config_tensorrt//:tensorrt",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:resources",
+    ],
+)
+
+py_library(
+    name = "init_py",
+    srcs = [
+        "__init__.py",
+        "python/__init__.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":trt_ops_py",
+        ":trt_convert_py",
+    ],
+)
+
+py_library(
+    name="trt_ops_py",
+    srcs_version = "PY2AND3",
+    deps=[":trt_engine_op",
+          ":trt_engine_op_loader",
+    ],
+    
+)
+
+py_library(
+    name="trt_convert_py",
+    srcs=["python/trt_convert.py"],
+    srcs_version = "PY2AND3",
+    deps=[
+        ":wrap_conversion"
+    ],
+)
+
+tf_py_wrap_cc(
+    name="wrap_conversion",
+    srcs=["trt_conversion.i"],
+    deps=[
+        ":trt_conversion",
+        "//tensorflow/core:framework_lite",
+        "//util/python:python_headers",
+    ],
+)
+
+cc_library(
+    name= "trt_conversion",
+    srcs=[
+        "convert/convert_nodes.cc",
+        "convert/convert_graph.cc",
+        "segment/segment.cc",
+        "convert/inferShapes.cc",
+    ],
+    hdrs = [
+        "convert/convert_nodes.h",
+        "convert/convert_graph.h",
+        "convert/inferShapes.h",
+        "segment/segment.h",
+        "segment/union_find.h",
+    ],
+    deps = [
+        "@local_config_tensorrt//:tensorrt",
+        "@protobuf_archive//:protobuf_headers",
+        "@nsync//:nsync_headers",
+        ":trt_logging",
+        "//tensorflow/core:framework_lite",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:core_cpu_base",
+        #"//third_party/eigen3",
+    ],
+)
+
+tf_custom_op_library(
+    name = "tensorrt_ops.so",
+    srcs = [
+        "ops/tensorrt_ops.cc",
+    ],
+    deps = [
+        "@local_config_tensorrt//:tensorrt",
+    ],
+)
+
+
+# Library for the segmenting portion of TensorRT operation creation
+cc_library(
+    name = "segment",
+    srcs = [
+        "segment/segment.cc",
+    ],
+    hdrs = [
+        "segment/union_find.h",
+        "segment/segment.h",
+    ],
+    deps = [
+        "@protobuf_archive//:protobuf_headers",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:lib_proto_parsing",
+        "//third_party/eigen3",
+    ],
+    linkstatic = 1,
+)
+
+tf_cc_test(
+    name = "segment_test",
+    size = "small",
+    srcs = ["segment/segment_test.cc"],
+    deps = [
+        ":segment",
+        "//tensorflow/c:c_api",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
+
+# Library for the node-level conversion portion of TensorRT operation creation
+
+filegroup(
+    name = "cppfiles",
+    srcs = glob(["**/*.cc"]),
+    visibility = ["//visibility:private"],
+)
+
+filegroup(
+    name = "headers",
+    srcs = glob(["**/*.h"]),
+    visibility = ["//visibility:private"],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md
new file mode 100644
index 0000000..61b348f
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/README.md
@@ -0,0 +1,42 @@
+Using TensorRT in TensorFlow
+============================
+
+This module provides the necessary bindings and introduces the
+TRT_engine_op operator, which wraps a subgraph in a TensorRT engine.
+
+Compilation
+-----------
+
+In order to compile the module, you need a local TensorRT installation
+(libnvinfer.so and the respective include files). During the
+configuration step, TensorRT has to be enabled and the installation
+path has to be set. If TensorRT was installed through a package manager
+(deb, rpm), the configure script should find the necessary components
+on the system automatically. If it was installed from a tar package,
+you have to point the configuration to the location where the library
+is installed.
+
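+The TensorRT options can also be provided non-interactively through
+environment variables before running the configure script (a sketch;
+adjust the version and path to your installation):
+
+```
+export TF_NEED_TENSORRT=1
+export TF_TENSORRT_VERSION=4
+export TENSORRT_INSTALL_PATH=/usr/lib/x86_64-linux-gnu
+./configure
+```
+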
+In order to enable TensorRT support, add `--config=tensorrt` to the
+build flags during compilation, for example:
+
+```
+bazel build --config=cuda --config=opt --config=tensorrt //tensorflow/tools/pip_package:build_pip_package
+bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
+```
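+
+The second command writes a pip wheel into `/tmp/`. The wheel can then be
+installed with pip, for example (the exact file name depends on the
+TensorFlow version and platform):
+
+```
+pip install /tmp/tensorflow-*.whl
+```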
+
+After installing the tensorflow package, the TensorRT transformation
+will be available. An example use is shown below.
+
+```python
+import tensorflow as tf
+import tensorflow.contrib.tensorrt as trt
+# ... create and train or load model
+gdef = sess.graph.as_graph_def()
+trt_gdef = trt.CreateInferenceGraph(
+    gdef,                # original graph_def
+    ["output"],          # name of output node(s)
+    max_batch_size,      # maximum batch size to run the inference
+    max_workspace_size)  # max memory for TensorRT to use
+tf.reset_default_graph()
+tf.import_graph_def(graph_def=trt_gdef)
+# ... run inference
+```
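+
+A minimal sketch of running inference on the imported graph (this assumes
+the original model has an input placeholder named "input" and an output
+node named "output", that `batch` is a NumPy array holding the input, and
+that `tf.import_graph_def` used its default "import" name prefix):
+
+```python
+with tf.Session() as sess:
+    result = sess.run("import/output:0", feed_dict={"import/input:0": batch})
+```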
diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py
new file mode 100644
index 0000000..0d69ffe
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tensorrt.python import *
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
new file mode 100644
index 0000000..29aa555
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -0,0 +1,253 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+
+#include <list>
+#include <set>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include <map>
+#include <utility>
+
+#include "NvInfer.h"
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
+//------------------------------------------------------------------------------
+namespace tensorrt {
+namespace convert {
+
+namespace {
+
+static std::unordered_set<std::string> output_nodes;
+bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
+  static const std::set<std::string> candidate_ops = {
+      "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
+      "Add",      "Mul",   "Sub",    "Rsqrt",   "Pad"  // "Placeholder" ,"Mean"
+                                                       // TODO(ben,jie): ...
+  };
+  if (output_nodes.count(node_def.name())) return false;
+  return candidate_ops.count(node_def.op());
+}
+
+void GetSubGraphIncomingEdges(tensorflow::Graph const& graph,
+                              std::set<int> const& subgraph_node_ids,
+                              tensorflow::EdgeSet* incoming_edges) {
+  for (int node_id : subgraph_node_ids) {
+    tensorflow::Node const* node = graph.FindNodeId(node_id);
+    LOG(DEBUG) << node->name() << " has incoming edges: ";
+    for (tensorflow::Edge const* edge : node->in_edges()) {
+      if (!subgraph_node_ids.count(edge->src()->id()) &&
+          !edge->src()->IsSource()) {
+        LOG(DEBUG) << edge->src()->name() << ", ";
+        incoming_edges->insert(edge);
+      }
+    }
+  }
+}
+
+void GetSubGraphOutgoingEdges(tensorflow::Graph const& graph,
+                              std::set<int> const& subgraph_node_ids,
+                              tensorflow::EdgeSet* outgoing_edges) {
+  for (int node_id : subgraph_node_ids) {
+    tensorflow::Node const* node = graph.FindNodeId(node_id);
+    LOG(DEBUG) << node->name() << " has outgoing edges: ";
+    for (tensorflow::Edge const* edge : node->out_edges()) {
+      if (!subgraph_node_ids.count(edge->dst()->id()) &&
+          !edge->dst()->IsSink()) {
+        outgoing_edges->insert(edge);
+      }
+    }
+  }
+}
+
+std::pair<std::string, int> ParseTensorName(std::string name,
+                                            int default_idx = 0) {
+  int idx = default_idx;
+  size_t sep = name.find_last_of(':');
+  if (sep != std::string::npos) {
+    idx = std::stoi(name.substr(sep + 1));
+    name = name.substr(0, sep);
+  }
+  return std::make_pair(name, idx);
+}
+
+std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap(
+    const std::vector<std::string>& tensor_names) {
+  std::unordered_map<std::string, std::vector<int>> result;
+  for (std::string const& tensor_name : tensor_names) {
+    std::string node_name;
+    int index;
+    std::tie(node_name, index) = ParseTensorName(tensor_name);
+    result[node_name].push_back(index);
+  }
+  return result;
+}
+
+tensorflow::Status ConvertSubGraphToTensorRT(
+    tensorflow::Graph& graph, const std::vector<std::string>& output_names,
+    const std::set<int>& subgraph_node_ids, size_t max_batch_size,
+    size_t max_workspace_size, const ShapeMap& shape_map) {
+  tensorflow::EdgeSet subgraph_incoming_edges;
+  GetSubGraphIncomingEdges(graph, subgraph_node_ids, &subgraph_incoming_edges);
+
+  std::vector<std::pair<int, int>> subgraph_inputs;
+
+
+  // Collect inputs by looking for incoming edges
+  for (tensorflow::Edge const* edge : subgraph_incoming_edges) {
+    subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
+  }
+  std::set<std::pair<int, int>> subgraph_outputs_set;
+  // Collect outputs referenced from output_names
+  auto output_name_to_index_map = BuildTensorNameMap(output_names);
+  // for (int node_id : subgraph_node_ids_no_placeholder) {
+  for (int node_id : subgraph_node_ids) {
+    tensorflow::Node* node = graph.FindNodeId(node_id);
+    if (output_name_to_index_map.count(node->name())) {
+      for (int index : output_name_to_index_map.at(node->name())) {
+        subgraph_outputs_set.insert({node_id, index});
+      }
+    }
+  }
+  // Collect outputs referenced from outgoing edges
+  tensorflow::EdgeSet subgraph_outgoing_edges;
+  // GetSubGraphOutgoingEdges(graph, subgraph_node_ids_no_placeholder,
+  //  &subgraph_outgoing_edges);
+  GetSubGraphOutgoingEdges(graph, subgraph_node_ids, &subgraph_outgoing_edges);
+  for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
+    subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
+  }
+  // Impose an ordering on the outputs
+  std::vector<std::pair<int, int>> subgraph_outputs(
+      subgraph_outputs_set.begin(), subgraph_outputs_set.end());
+  // Build TensorRT node and add it to the graph
+  tensorflow::NodeDef trt_node_def;
+  TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
+      graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
+      max_batch_size, max_workspace_size, shape_map, &trt_node_def));
+  tensorflow::Status status;
+  tensorflow::Node* trt_node = graph.AddNode(trt_node_def, &status);
+
+  TF_RETURN_IF_ERROR(status);
+
+  // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
+  std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
+  for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
+    subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
+  }
+  TF_RETURN_IF_ERROR(status);
+  for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
+    std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
+    int new_src_output = subgraph_edge_to_output_map.at(old_src);
+    graph.UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input());
+  }
+  // Remove the original subgraph
+  for (int node_id : subgraph_node_ids) {
+    tensorflow::Node* node = graph.FindNodeId(node_id);
+    // Don't remove the input placeholders
+    if (node->type_string() == "Placeholder") {
+      continue;
+    }
+    graph.RemoveNode(node);
+  }
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status BuildNodeMap(
+    const tensorflow::Graph& graph,
+    std::unordered_map<std::string, tensorflow::Node*>* node_map) {
+  for (auto* node : graph.op_nodes()) {
+    if (!node_map->insert({node->name(), node}).second) {
+      return tensorflow::errors::AlreadyExists(
+          "Node name is not unique in graph: " + node->name());
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
+}  // namespace
+
+tensorflow::Status ConvertGraphDefToTensorRT(
+    const tensorflow::GraphDef& graph_def,
+    const std::vector<std::string>& output_names, size_t max_batch_size,
+    size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) {
+  ShapeMap shape_map;
+  TF_RETURN_IF_ERROR(
+      tensorflow::trt::inferShapes(graph_def, output_names, shape_map));
+  std::stringstream oss;
+  for (auto& n : shape_map) {  // nodes
+    oss << " Node= " << n.first << ", ";
+    for (auto o : n.second) {  // outputs
+      oss << o.first.DebugString() << " T= " << o.second << ", ";
+    }
+    LOG(DEBUG) << oss.str();
+    oss.str("");
+  }
+  // Build full graph
+  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
+                                             graph_def.library());
+  tensorflow::Graph graph(flib);
+  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+      tensorflow::GraphConstructorOptions(), graph_def, &graph));
+
+  // Segment the graph into subgraphs that can be converted to TensorRT
+  tensorrt::segment::SegmentOptions segment_options;
+  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
+  for (auto node : output_names) output_nodes.insert(node);
+
+  // TODO(sami): this should be passed as a knob!!!!
+  segment_options.minimum_segment_size = 2;
+  tensorrt::segment::SegmentNodesVector segments;
+  TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
+      graph_def, IsTensorRTCandidate, segment_options, &segments));
+  if (segments.size() > 1) {
+    // LOG(WARNING) << "Multiple TensorRT candidate subgraphs were found, "
+    //<< "but only the first can be converted.";
+    // segments.erase(++segments.begin(), segments.end());
+    LOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
+  }
+  std::unordered_map<std::string, tensorflow::Node*> node_map;
+  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
+  for (std::set<std::string> const& subgraph_node_names : segments) {
+    std::set<int> subgraph_node_ids;
+    for (std::string const& node_name : subgraph_node_names) {
+      subgraph_node_ids.insert(node_map.at(node_name)->id());
+    }
+    TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
+        graph, output_names, subgraph_node_ids, max_batch_size,
+        max_workspace_size, shape_map));
+  }
+  graph.ToGraphDef(new_graph_def);
+  return tensorflow::Status::OK();
+}
+
+}  // namespace convert
+}  // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
new file mode 100644
index 0000000..cd713de
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorrt {
+namespace convert {
+
+tensorflow::Status ConvertGraphDefToTensorRT(
+    const tensorflow::GraphDef& graph_def,
+    const std::vector<std::string>& output_names, size_t max_batch_size,
+    size_t max_workspace_size, tensorflow::GraphDef* new_graph_def);
+}  // namespace convert
+}  // namespace tensorrt
+
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
new file mode 100644
index 0000000..03146b1
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -0,0 +1,1737 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+
+#include <algorithm>
+#include <fstream>
+#include <list>
+#include <map>
+#include <memory>
+#include <set>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "NvInfer.h"
+
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
+//  Check if the types are equal. Cast to int first so that failure log message
+//  would work!
+#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
+//------------------------------------------------------------------------------
+namespace tensorrt {
+namespace convert {
+
+namespace {
+
+inline int get_dtype_size(nvinfer1::DataType trt_dtype) {
+  switch (trt_dtype) {
+    case nvinfer1::DataType::kFLOAT:
+      return 4;
+    case nvinfer1::DataType::kINT8:
+      return 1;
+    case nvinfer1::DataType::kHALF:
+      return 2;
+    default:
+      return -1;
+  }
+}
+
+inline int get_dtype_size(tensorflow::DataType trt_dtype) {
+  switch (trt_dtype) {
+    case tensorflow::DataType::DT_FLOAT:
+      return 4;
+    case tensorflow::DataType::DT_INT8:
+      return 1;
+    case tensorflow::DataType::DT_HALF:
+      return 2;
+    case tensorflow::DataType::DT_INT32:
+      return 4;
+    default:
+      return -1;
+  }
+}
+
+inline tensorflow::Status convert_dtype(tensorflow::DataType tf_dtype,
+                                        nvinfer1::DataType* trt_dtype) {
+  switch (tf_dtype) {
+    case tensorflow::DataType::DT_FLOAT:
+      *trt_dtype = nvinfer1::DataType::kFLOAT;
+      break;
+    case tensorflow::DataType::DT_INT8:
+      *trt_dtype = nvinfer1::DataType::kINT8;
+      break;
+    case tensorflow::DataType::DT_HALF:
+      *trt_dtype = nvinfer1::DataType::kHALF;
+      break;
+    default:
+      return tensorflow::errors::InvalidArgument("Unsupported data type");
+  }
+  return tensorflow::Status::OK();
+}
+
+inline nvinfer1::Dims get_tensor_shape(const tensorflow::Tensor& tensor) {
+  nvinfer1::Dims dims;
+  dims.nbDims = tensor.dims();
+  for (int i = 0; i < dims.nbDims; i++) {
+    dims.d[i] = tensor.dim_size(i);
+  }
+  return dims;
+}
+
+inline int64_t get_shape_size(nvinfer1::Dims shape) {
+  // Returns total number of elements in shape
+  int64_t count = 1;
+  for (int d = 0; d < shape.nbDims; ++d) {
+    count *= shape.d[d];
+  }
+  return count;
+}
+
+static std::vector<std::pair<int, int>> createSamePadding(
+    nvinfer1::DimsHW& stride, nvinfer1::DimsHW& kernel,
+    std::vector<int64_t> inputDims) {
+  std::vector<std::pair<int, int>> padding(inputDims.size());
+  CHECK_EQ((size_t)stride.nbDims, inputDims.size());  // TODO(jie): N+C? NC+?
+
+  for (size_t i = 0; i < inputDims.size(); ++i) {
+    /* formula to calculate the padding */
+    int p = ((inputDims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
+            inputDims[i];
+    p = (p > 0) ? p : 0;
+
+    /* right precedence padding, like in TensorFlow */
+    int left = p / 2;
+    int right = p - left;
+
+    padding[i] = {left, right};
+  }
+  return padding;
+}
+
+// class TRT_ShapedWeights : public nvinfer1::Weights {
+class TRT_ShapedWeights {
+ public:
+  nvinfer1::Dims shape_;
+  tensorflow::DataType type_;
+  const void* values_;
+  bool dummy_flag_;
+  int64_t count() const {
+    int64_t c = 1;
+    for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
+    return c;
+  }
+  TRT_ShapedWeights(tensorflow::DataType type, const void* values,
+                    nvinfer1::Dims shape)
+      : shape_(shape), type_(type), values_(values), dummy_flag_(false) {
+    // Note: this->shape.type[] is not used
+  }
+  explicit TRT_ShapedWeights(tensorflow::DataType type)
+      : type_(type), values_(nullptr), dummy_flag_(true) {}
+  nvinfer1::Weights getWeightsForTRT() const {
+    nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
+    TF_CHECK_OK(convert_dtype(type_, &trt_type));
+    if (dummy_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};
+
+    // Note: this->shape.type[] is not used
+    return nvinfer1::Weights{trt_type, values_, get_shape_size(shape_)};
+  }
+  size_t size_bytes() const {
+    return this->count() * get_dtype_size(this->type_);
+  }
+  // default converter
+  operator nvinfer1::Weights() const { return getWeightsForTRT(); }
+};
+
+class TRT_TensorOrWeights {
+  union {
+    nvinfer1::ITensor* _tensor_;
+    TRT_ShapedWeights _weights_;
+  };
+  enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } _variant_;
+
+ public:
+  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
+      : _tensor_(tensor), _variant_(TRT_NODE_TENSOR) {}
+  explicit TRT_TensorOrWeights(TRT_ShapedWeights const& weights)
+      : _weights_(weights), _variant_(TRT_NODE_WEIGHTS) {}
+  TRT_TensorOrWeights() = delete;
+  bool is_tensor() const { return _variant_ == TRT_NODE_TENSOR; }
+  bool is_weights() const { return _variant_ == TRT_NODE_WEIGHTS; }
+  nvinfer1::ITensor* tensor() {
+    CHECK_EQ(this->is_tensor(), true);
+    return _tensor_;
+  }
+  nvinfer1::ITensor const* tensor() const {
+    CHECK_EQ(this->is_tensor(), true);
+    return _tensor_;
+  }
+  TRT_ShapedWeights& weights() {
+    CHECK_EQ(this->is_weights(), true);
+    return _weights_;
+  }
+  TRT_ShapedWeights const& weights() const {
+    CHECK_EQ(this->is_weights(), true);
+    return _weights_;
+  }
+  nvinfer1::Dims shape() const {
+    if (this->is_tensor()) {
+      return this->tensor()->getDimensions();
+    } else {
+      return this->weights().shape_;
+    }
+  }
+};
+
+class TRT_LayerOrWeights {
+  union {
+    nvinfer1::ILayer* _layer_;
+    TRT_ShapedWeights _weights_;
+  };
+  enum { TRT_NODE_LAYER, TRT_NODE_WEIGHTS } _variant_;
+
+ public:
+  explicit TRT_LayerOrWeights(nvinfer1::ILayer* layer)
+      : _layer_(layer), _variant_(TRT_NODE_LAYER) {}
+  explicit TRT_LayerOrWeights(TRT_ShapedWeights const& weights)
+      : _weights_(weights), _variant_(TRT_NODE_WEIGHTS) {}
+  bool is_layer() const { return _variant_ == TRT_NODE_LAYER; }
+  bool is_weights() const { return _variant_ == TRT_NODE_WEIGHTS; }
+  nvinfer1::ILayer* layer() {
+    CHECK_EQ(this->is_layer(), true);
+    return _layer_;
+  }
+  TRT_ShapedWeights& weights() {
+    CHECK_EQ(this->is_weights(), true);
+    return _weights_;
+  }
+  TRT_TensorOrWeights output(int index = 0) const {
+    if (this->is_layer()) {
+      nvinfer1::ITensor* tensor = _layer_->getOutput(index);
+      return TRT_TensorOrWeights(tensor);
+    } else {
+      CHECK_EQ(index, 0);
+      return TRT_TensorOrWeights(_weights_);
+    }
+  }
+};
+
+class TFAttrs {
+  typedef std::map<std::string, tensorflow::AttrValue const*> AttrMap;
+  AttrMap _attrs;
+
+ public:
+  explicit TFAttrs(tensorflow::NodeDef const& tf_node) {
+    for (auto const& attr : tf_node.attr()) {
+      _attrs.insert({attr.first, &attr.second});
+    }
+  }
+  bool count(std::string key) const { return _attrs.count(key); }
+  tensorflow::AttrValue const* at(std::string key) const {
+    if (!_attrs.count(key)) {
+      throw std::out_of_range("Attribute not found: " + key);
+    }
+    return _attrs.at(key);
+  }
+  template <typename T>
+  T get(std::string key) const;
+  template <typename T>
+  T getShape(std::string key) const;
+  template <typename T>
+  T get(std::string key, T const& default_value) const {
+    return _attrs.count(key) ? this->get<T>(key) : default_value;
+  }
+};
+// template <>
+// float TFAttrs::get<float>(std::string key) const {
+//  return this->at(key)->f();
+//}
+
+// template <>
+// int TFAttrs::get<int>(std::string key) const {
+//  return (int)this->at(key)->i();
+//}
+
+// template <>
+// bool TFAttrs::get<bool>(std::string key) const {
+//  auto value = this->at(key)->i();
+//  return bool(value);
+//}
+
+template <>
+std::string TFAttrs::get<std::string>(std::string key) const {
+  return this->at(key)->s();
+}
+template <>
+std::vector<int> TFAttrs::get<std::vector<int>>(std::string key) const {
+  auto attr = this->at(key)->list().i();
+  return std::vector<int>(attr.begin(), attr.end());
+}
+template <>
+nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(std::string key) const {
+  auto values = this->get<std::vector<int>>(key);
+  nvinfer1::Dims dims;
+  dims.nbDims = values.size();
+  std::copy(values.begin(), values.end(), dims.d);
+  // Note: No dimension type information is included
+  return dims;
+}
+// template <>
+// nvinfer1::DimsHW TFAttrs::get<nvinfer1::DimsHW>(std::string key) const {
+//  nvinfer1::Dims dims = this->get<nvinfer1::Dims>(key);
+//  CHECK_EQ(dims.nbDims, 2);
+//  return nvinfer1::DimsHW(dims.d[0], dims.d[1]);
+//}
+// template <>
+// nvinfer1::Permutation TFAttrs::get<nvinfer1::Permutation>(
+//    std::string key) const {
+//  auto values = this->get<std::vector<int>>(key);
+//  nvinfer1::Permutation perm;
+//  std::copy(values.begin(), values.end(), perm.order);
+//  // Fill unused values with -1 to aid debugging
+//  std::fill(perm.order + values.size(), perm.order + nvinfer1::Dims::MAX_DIMS,
+//            -1);
+//  return perm;
+//}
+// template <>
+// nvinfer1::Dims TFAttrs::getShape<nvinfer1::Dims>(std::string key) const {
+//  auto attr = this->at(key)->shape();
+//  nvinfer1::Dims dims;
+//  dims.nbDims = attr.dim_size();
+//  for (int i = 0; i < dims.nbDims; i++) dims.d[i] = attr.dim(i).size();
+//  return dims;
+//}
+// template<> TRT_ShapedWeights TFAttrs::get<TRT_ShapedWeights>(std::string key)
+// const {
+//  tensorflow::TensorProto const* tf_weights_tensor = &this->at(key)->tensor();
+// TODO(jie): Implement this
+//  return convert_tf_weights(tf_weights_tensor);
+//}
+template <>
+nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(std::string key) const {
+  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
+  TF_CHECK_OK(convert_dtype(this->at(key)->type(), &trt_dtype));
+  return trt_dtype;
+}
+template <>
+tensorflow::DataType TFAttrs::get<tensorflow::DataType>(std::string key) const {
+  return this->at(key)->type();
+}
+
+template <typename T>
+void reorder4(nvinfer1::DimsNCHW shape, T const* idata,
+              nvinfer1::DimsNCHW istrides, T* odata,
+              nvinfer1::DimsNCHW ostrides) {
+  for (int n = 0; n < shape.n(); ++n) {
+    for (int c = 0; c < shape.c(); ++c) {
+      for (int h = 0; h < shape.h(); ++h) {
+        for (int w = 0; w < shape.w(); ++w) {
+          odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
+                w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
+                                          h * istrides.h() + w * istrides.w()];
+        }
+      }
+    }
+  }
+}
+
+void reorder_rsck_to_kcrs(TRT_ShapedWeights const& iweights,
+                          TRT_ShapedWeights* oweights) {
+  CHECK_EQ(iweights.type_, oweights->type_);
+  CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
+  int r = iweights.shape_.d[0];
+  int s = iweights.shape_.d[1];
+  int c = iweights.shape_.d[2];
+  int k = iweights.shape_.d[3];
+  oweights->shape_.d[0] = k;
+  oweights->shape_.d[1] = c;
+  oweights->shape_.d[2] = r;
+  oweights->shape_.d[3] = s;
+  // nvinfer1::DimsNCHW istrides = {1, s, c*r*s, r*s};
+  nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
+  nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
+  switch (iweights.type_) {
+    case tensorflow::DataType::DT_FLOAT:
+      reorder4(
+          {k, c, r, s}, static_cast<float const*>(iweights.values_), istrides,
+          static_cast<float*>(const_cast<void*>(oweights->values_)), ostrides);
+      break;
+    default:
+      LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!";
+  }
+}
+
+/* not used. clean up needed.
+nvinfer1::Weights make_dummy_weights(nvinfer1::DataType
+dtype=nvinfer1::DataType::kFLOAT) { nvinfer1::Weights w; w.count  = 0; w.values
+= nullptr; w.type   = dtype; return w;
+}
+*/
+
+struct InferDeleter {
+  template <typename T>
+  void operator()(T* obj) const {
+    if (obj) {
+      obj->destroy();
+    }
+  }
+};
+
+template <typename T>
+inline std::shared_ptr<T> infer_object(T* obj) {
+  return std::shared_ptr<T>(obj, InferDeleter());
+}
+
+// Logger for GIE info/warning/errors
+class Converter;
+
+using OpConverter =
+    std::function<tensorflow::Status(Converter&, tensorflow::NodeDef const&,
+                                     std::vector<TRT_TensorOrWeights> const&,
+                                     std::vector<TRT_TensorOrWeights>*)>;
+
+class Converter {
+  std::unordered_map<std::string, TRT_TensorOrWeights> _trt_tensors;
+  std::unordered_map<std::string, OpConverter> _op_registry;
+  nvinfer1::INetworkDefinition* _trt_network;
+  std::list<std::vector<uint8_t>> _temp_bufs;
+
+  void register_op_converters();
+
+  std::vector<TRT_TensorOrWeights> get_inputs(
+      tensorflow::NodeDef const& node_def) {
+    std::vector<TRT_TensorOrWeights> inputs;
+    for (auto const& input_name : node_def.input()) {
+      LOG(DEBUG) << "retrieve input: " << input_name;
+      inputs.push_back(_trt_tensors.at(input_name));
+    }
+    return inputs;
+  }
+
+ public:
+  explicit Converter(nvinfer1::INetworkDefinition* trt_network)
+      : _trt_network(trt_network) {
+    this->register_op_converters();
+  }
+
+  TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
+                                     nvinfer1::Dims shape) {
+    TRT_ShapedWeights weights(type, nullptr, shape);
+    _temp_bufs.push_back(std::vector<uint8_t>(weights.size_bytes()));
+    weights.values_ = _temp_bufs.back().data();
+    return weights;
+  }
+
+  TRT_ShapedWeights get_temp_weights_like(TRT_ShapedWeights const& weights) {
+    return this->get_temp_weights(weights.type_, weights.shape_);
+  }
+
+  tensorflow::Status convert_node(tensorflow::NodeDef const& node_def) {
+    std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
+    std::string op = node_def.op();
+    if (!_op_registry.count(op)) {
+      return tensorflow::errors::Unimplemented(
+          "no converter registered for op: " + op);
+    }
+    OpConverter op_converter = _op_registry.at(op);
+    std::vector<TRT_TensorOrWeights> outputs;
+    TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      TRT_TensorOrWeights output = outputs.at(i);
+      // TODO(jie): tf protobuf seems to be omitting the :0 suffix
+      std::string output_name = node_def.name();
+      if (i != 0) output_name = output_name + ":" + std::to_string(i);
+      if (output.is_tensor()) {
+        output.tensor()->setName(output_name.c_str());
+      }
+      LOG(DEBUG) << "write out tensor: " << output_name;
+      if (!_trt_tensors.insert({output_name, output}).second) {
+        return tensorflow::errors::AlreadyExists(
+            "output tensor already exists for op: " + op);
+      }
+    }
+    return tensorflow::Status::OK();
+  }
+
+  nvinfer1::INetworkDefinition* network() { return _trt_network; }
+
+  TRT_TensorOrWeights get_tensor(std::string name) {
+    if (!_trt_tensors.count(name)) {
+      return TRT_TensorOrWeights(nullptr);
+    }
+    return _trt_tensors.at(name);
+  }
+
+  bool insert_input_tensor(std::string name, nvinfer1::ITensor* tensor) {
+    return _trt_tensors.insert({name, TRT_TensorOrWeights(tensor)}).second;
+  }
+
+  nvinfer1::ITensor* transposeTensor(nvinfer1::ITensor* input_tensor,
+                                     std::vector<int> order) {
+    auto dims = input_tensor->getDimensions();
+
+    // TODO(jie): change the return to status and properly exit
+    if (order.size() - 1 != size_t(dims.nbDims))
+      LOG(ERROR) << "dimension does not match, fail gracefully";
+
+    nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
+    nvinfer1::Permutation permutation;
+    for (int32_t i = 0; i < dims.nbDims; ++i) {
+      permutation.order[i] = order[i + 1] - 1;
+    }
+    layer->setFirstTranspose(permutation);
+
+    nvinfer1::Dims reshapeDims;
+    reshapeDims.nbDims = dims.nbDims;
+    for (int32_t i = 0; i < reshapeDims.nbDims; ++i) {
+      reshapeDims.d[i] = 0;
+      reshapeDims.type[i] = dims.type[i];
+    }
+    layer->setReshapeDimensions(reshapeDims);
+    return layer->getOutput(0);
+  }
+};
+
+/*******************************************************************************
+  Constant folding functions
+  TODO(jie): once optimizer kicks in, we should have done constant folding
+there.
+*******************************************************************************/
+struct LambdaFactory {
+  enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
+  OP_CATEGORY op;
+
+  template <typename T>
+  std::function<T(T)> unary() {
+    switch (op) {
+      case OP_CATEGORY::RSQRT: {
+        LOG(DEBUG) << "RSQRT GETS DONE";
+        return [](T t) -> T { return 1.0 / std::sqrt(t); };
+      }
+      case OP_CATEGORY::NEG:
+        return [](T t) -> T { return -t; };
+      default:
+        LOG(DEBUG) << "not supported op for unary: " << static_cast<int>(op);
+        return nullptr;
+    }
+  }
+
+  template <typename T>
+  std::function<T(T, T)> binary() {
+    switch (op) {
+      case OP_CATEGORY::ADD:
+        return [](T l, T r) -> T { return l + r; };
+      case OP_CATEGORY::SUB:
+        return [](T l, T r) -> T { return l - r; };
+      case OP_CATEGORY::MUL:
+        return [](T l, T r) -> T { return l * r; };
+      default:
+        LOG(WARNING) << "not supported op for binary: " << static_cast<int>(op);
+    }
+    return [](T l, T r) -> T {
+      LOG(FATAL) << "Unsupported op type ";
+      return l;
+    };
+  }
+
+  template <typename T>
+  std::function<T(T)> broadcast_r(T val) {
+    LOG(DEBUG) << "LAMBDA VAL : " << val;
+    switch (op) {
+      case OP_CATEGORY::ADD:
+        return [val](T l) -> T {
+          LOG(DEBUG) << "LAMBDA VAL : " << val;
+          return l + val;
+        };
+        // return [val](T l)-> T {return l+val;};
+      case OP_CATEGORY::SUB:
+        return [val](T l) -> T {
+          LOG(DEBUG) << "LAMBDA VAL : " << val;
+          return l - val;
+        };
+      case OP_CATEGORY::MUL:
+        return [val](T l) -> T {
+          LOG(DEBUG) << "LAMBDA VAL : " << val;
+          return l * val;
+        };
+      default:
+        LOG(WARNING) << "not supported op for binary: " << static_cast<int>(op);
+    }
+    return [val](T l) -> T {
+      LOG(FATAL) << "Unsupported op type ";
+      return l;
+    };
+  }
+
+  template <typename T>
+  std::function<T(T)> broadcast_l(T val) {
+    LOG(DEBUG) << "LAMBDA VAL : " << val;
+    switch (op) {
+      case OP_CATEGORY::ADD:
+        return [val](T l) -> T {
+          LOG(DEBUG) << "LAMBDA VAL : " << val;
+          return val + l;
+        };
+      case OP_CATEGORY::SUB:
+        return [val](T l) -> T {
+          LOG(DEBUG) << "LAMBDA VAL : " << val;
+          return val - l;
+        };
+      case OP_CATEGORY::MUL:
+        return [val](T l) -> T {
+          LOG(DEBUG) << "LAMBDA VAL : " << val;
+          return val * l;
+        };
+      default:
+        LOG(ERROR) << "not supported op for binary: " << static_cast<int>(op);
+    }
+    return [val](T l) -> T {
+      LOG(FATAL) << "Unsupported op type ";
+      return l;
+    };
+  }
+};
+
+tensorflow::Status UnaryCompute(TRT_ShapedWeights const& iweights,
+                                TRT_ShapedWeights* oweights,
+                                LambdaFactory unary_op) {
+  // assume iweights.type == oweights.type
+  CHECK_EQ(iweights.type_, oweights->type_);
+
+  switch (iweights.type_) {
+    case tensorflow::DataType::DT_FLOAT: {
+      auto inp = static_cast<float const*>(iweights.values_);
+      auto oup = static_cast<float*>(const_cast<void*>(oweights->values_));
+      std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
+      break;
+    }
+    default:
+      return tensorflow::errors::Unimplemented("data type not supported: " +
+                                               iweights.type_);
+  }
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status BinaryCompute(TRT_ShapedWeights const& iweights_l,
+                                 TRT_ShapedWeights const& iweights_r,
+                                 TRT_ShapedWeights* oweights,
+                                 LambdaFactory binary_op) {
+  // assume iweights_l.type == iweight_r.type
+  CHECK_EQ(iweights_l.type_, oweights->type_);
+  CHECK_EQ(iweights_r.type_, oweights->type_);
+  LOG(DEBUG) << "SANITY CHECK!";
+
+  switch (iweights_l.type_) {
+    case tensorflow::DataType::DT_FLOAT: {
+      auto inp_l = static_cast<float const*>(iweights_l.values_);
+      auto inp_r = static_cast<float const*>(iweights_r.values_);
+      auto oup = static_cast<float*>(const_cast<void*>(oweights->values_));
+
+      if (iweights_l.count() != iweights_r.count()) {
+        // we only support broadcasting of rank-zero tensors
+        if (iweights_l.count() == 1) {
+          LOG(DEBUG) << "I bet it is not working!" << (*inp_l);
+          std::transform(inp_r, inp_r + iweights_r.count(), oup,
+                         binary_op.broadcast_l<float>(*inp_l));
+        } else if (iweights_r.count() == 1) {
+          LOG(DEBUG) << "I bet it is not working!" << (*inp_r);
+          std::transform(inp_l, inp_l + iweights_l.count(), oup,
+                         binary_op.broadcast_r<float>(*inp_r));
+        } else {
+          return tensorflow::errors::Unimplemented(
+              "Binary op with non-rankZero broadcast not supported");
+        }
+      } else {
+        std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
+                       binary_op.binary<float>());
+      }
+      break;
+    }
+    default:
+      return tensorflow::errors::Unimplemented("data type not supported: " +
+                                               iweights_l.type_);
+  }
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConstantFoldUnary(
+    Converter& ctx, tensorflow::NodeDef const& node_def,
+    std::vector<TRT_TensorOrWeights> const& inputs,
+    std::vector<TRT_TensorOrWeights>* outputs) {
+  TRT_ShapedWeights weights_input = inputs.at(0).weights();
+
+  // allocate output weights
+  TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);
+
+  // FIXME assume type matches input weights
+  // get trt type & shape
+  // maybe this part has to be moved into the block of rsqrt later
+  // check type consistency
+  CHECK_EQ(weights_input.type_,
+           TFAttrs(node_def).get<tensorflow::DataType>("T"));
+
+  // Maybe I should do a switch
+  LambdaFactory unary_op;
+  if (node_def.op() == "Rsqrt") {
+    // compute rsqrt
+    unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
+    auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
+    // pass the output
+    if (ret == tensorflow::Status::OK()) {
+      outputs->push_back(TRT_TensorOrWeights(weights_output));
+    }
+    return ret;
+  } else {
+    return tensorflow::errors::Unimplemented("Binary op not supported: " +
+                                             node_def.op());
+  }
+}
+
+// TODO(jie,ben) broadcast is needed yet not implemented
+// Let's get the simple stuff working first. Maybe we should fall back to the TF
+//   approach for constant folding
+tensorflow::Status ConstantFoldBinary(
+    Converter& ctx, tensorflow::NodeDef const& node_def,
+    std::vector<TRT_TensorOrWeights> const& inputs,
+    std::vector<TRT_TensorOrWeights>* outputs) {
+  TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
+  TRT_ShapedWeights weights_input_r = inputs.at(1).weights();
+
+  // check type consistency
+  CHECK_EQ(weights_input_l.type_, weights_input_r.type_);
+
+  if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
+    return tensorflow::errors::Unimplemented(
+        "Binary op implicit broadcast not supported: " + node_def.op());
+
+  // TODO(jie): constant fold should really fall back to TF.
+  int nbDims = weights_input_l.shape_.nbDims;
+  nvinfer1::Dims output_shape;
+  output_shape.nbDims = nbDims;
+  LOG(DEBUG) << "nbDims: " << nbDims
+             << "the other: " << weights_input_r.shape_.nbDims;
+  for (int i = 0; i < nbDims; i++) {
+    if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
+      output_shape.d[i] = weights_input_l.shape_.d[i];
+    } else if (weights_input_l.shape_.d[i] == 1 ||
+               weights_input_r.shape_.d[i] == 1) {
+      output_shape.d[i] =
+          std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
+    } else {
+      return tensorflow::errors::Unimplemented(
+          "Binary op with incompatible shape at, " + node_def.op());
+    }
+    LOG(DEBUG) << "left: " << weights_input_l.shape_.d[i]
+               << ", right: " << weights_input_r.shape_.d[i]
+               << ", output: " << output_shape.d[i];
+  }
+
+  // FIXME assume type matches input weights
+  // get trt type & shape
+  TFAttrs attrs(node_def);
+  // maybe this part has to be moved into the block of rsqrt later
+  tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");
+
+  // allocate output weights
+  TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);
+
+  // TODO(jie): switch on the op name once more binary ops are supported.
+  LambdaFactory binary_op;
+  if (node_def.op() == "Sub") {
+    binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
+  } else if (node_def.op() == "Mul") {
+    binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
+  } else if (node_def.op() == "Add") {
+    binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
+  } else {
+    return tensorflow::errors::Unimplemented("Binary op not supported: " +
+                                             node_def.op());
+  }
+  auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
+                           binary_op);
+
+  // pass the output
+  if (ret == tensorflow::Status::OK()) {
+    outputs->push_back(TRT_TensorOrWeights(weights_output));
+  }
+
+  return ret;
+}
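+
+// Worked example of the shape rule above (both inputs must have the same
+// rank; a dimension may only broadcast when one side is 1):
+//
+//   left  shape: [1, 64,  1,  1]
+//   right shape: [1, 64, 32, 32]
+//   output     : [1, 64, 32, 32]   (element-wise max of each dimension pair)
+//
+// Mismatched non-1 dimensions, or differing ranks, return Unimplemented.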
+
+// TODO(jie): broadcast is needed yet not implemented
+// only implemented channel wise for the time being
+tensorflow::Status BinaryTensorOpWeight(
+    Converter& ctx, tensorflow::NodeDef const& node_def,
+    const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
+    std::vector<TRT_TensorOrWeights>* outputs) {
+  // FIXME assume type matches input weights
+  // get trt type & shape
+  // maybe this part has to be moved into the block of rsqrt later
+
+  // check type consistency
+  auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
+  CHECK_EQ_TYPE(tensor->getType(), dtype);  // cast to int for error messages
+  nvinfer1::DataType ttype;
+  TF_CHECK_OK(convert_dtype(weights.type_, &ttype));
+  CHECK_EQ_TYPE(ttype, dtype);  // cast to int for error message
+
+  // check scale mode
+  auto dims_w = weights.shape_;
+  auto dims_t = tensor->getDimensions();
+
+  // default to channel-wise
+  auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
+
+  /*
+  if (weights.count() == 1) {
+    LOG(DEBUG) << "UNIFORM";
+    scale_mode = nvinfer1::ScaleMode::kUNIFORM;
+  } else if (dims_w.nbDims == 1) {
+    // TODO(jie): should we check for implicit channel-wise binary op
+    //   where weights has shape 1x1xC?
+    LOG(DEBUG) << "CHANNEL";
+    scale_mode = nvinfer1::ScaleMode::kCHANNEL;
+  } else {
+    // TODO(jie): check weight shape.
+    // broadcast is not fully supported
+    LOG(DEBUG) << "ELEMENTWISE";
+    scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
+  } */
+
+  if (weights.count() == 1) {
+    LOG(DEBUG) << "UNIFORM";
+    scale_mode = nvinfer1::ScaleMode::kUNIFORM;
+  } else {
+    // No broadcasting on the batch dimension.
+    assert(dims_w.d[0] == 1);
+
+    // Broadcasting on the channel dimension is only allowed in kUNIFORM.
+    assert(dims_w.d[1] == dims_t.d[0]);
+    assert(dims_w.nbDims == dims_t.nbDims);
+
+    // Default to element-wise; switch to channel-wise as soon as any spatial
+    // dimension of the weights differs from the tensor's.
+    for (int i = 2; i < dims_w.nbDims; i++) {
+      if (dims_w.d[i] != dims_t.d[i - 1]) {
+        scale_mode = nvinfer1::ScaleMode::kCHANNEL;
+        break;
+      }
+    }
+    if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
+      for (int i = 2; i < dims_w.nbDims; i++) {
+        if (dims_w.d[i] != 1)
+          return tensorflow::errors::InvalidArgument(
+              "Weight shape not compatible, at " + node_def.name());
+      }
+    }
+  }
+
+  // transpose last dimension
+  /*
+  std::vector<int> permutation(dims_t.nbDims + 1);
+  if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) {
+    // we swap the last dimension into channel for trt.
+    // because of tensorflow default broadcasting rules.
+    for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
+      permutation[i] = i;
+    }
+    permutation[1] = dims_t.nbDims;
+    permutation[dims_t.nbDims] = 1;
+    tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+                                 permutation);
+  }
+  */
+
+  // prepare weights
+  TRT_ShapedWeights shiftWeights(weights.type_);
+  TRT_ShapedWeights scaleWeights(weights.type_);
+  TRT_ShapedWeights powerWeights(weights.type_);
+
+  // TODO(jie): use a switch once more ops map onto the scale layer.
+  if (node_def.op() == "Sub") {
+    TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
+    LambdaFactory unary_op;
+    unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
+    UnaryCompute(weights, &neg_weights, unary_op);
+    shiftWeights = neg_weights;
+  } else if (node_def.op() == "Mul") {
+    scaleWeights = weights;
+  } else if (node_def.op() == "Add") {
+    shiftWeights = weights;
+  } else {
+    return tensorflow::errors::Unimplemented("Binary op not supported: " +
+                                             node_def.op());
+  }
+
+  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
+      *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shiftWeights,
+      scaleWeights, powerWeights);
+
+  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+  // transpose back dimension
+  /*
+  if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) {
+    output_tensor = ctx.transposeTensor(output_tensor, permutation);
+  }
+  */
+
+  // pass the output
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
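+
+// Sketch of the mapping above: the IScaleLayer computes, per element,
+// output = (input * scale + shift) ^ power, with the omitted terms left
+// empty. For a tensor x and weights w:
+//
+//   Add: shift = w              ->  x + w
+//   Sub: shift = -w (NEG fold)  ->  x - w
+//   Mul: scale = w              ->  x * w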
+
+tensorflow::Status BinaryTensorOpTensor(
+    Converter& ctx, tensorflow::NodeDef const& node_def,
+    const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
+    std::vector<TRT_TensorOrWeights>* outputs) {
+  static const std::unordered_map<std::string, nvinfer1::ElementWiseOperation>
+      ops{
+          {"Add", nvinfer1::ElementWiseOperation::kSUM},
+          {"Mul", nvinfer1::ElementWiseOperation::kPROD},
+          // {"max", nvinfer1::ElementWiseOperation::kMAX},
+          // {"min", nvinfer1::ElementWiseOperation::kMIN},
+          {"Sub", nvinfer1::ElementWiseOperation::kSUB},
+          {"Div", nvinfer1::ElementWiseOperation::kDIV},
+      };
+
+  // FIXME assume type matches input weights
+  // get trt type & shape
+  TFAttrs attrs(node_def);
+  // maybe this part has to be moved into the block of rsqrt later
+  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
+
+  // check type consistency
+  CHECK_EQ_TYPE(tensor_l->getType(), dtype);
+  CHECK_EQ_TYPE(tensor_r->getType(), dtype);
+  auto op_pair = ops.find(node_def.op());
+  if (op_pair == ops.end())
+    return tensorflow::errors::Unimplemented(
+        "binary op: " + node_def.op() +
+        " not supported at: " + node_def.name());
+
+  nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
+      *const_cast<nvinfer1::ITensor*>(tensor_l),
+      *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
+
+  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+  // pass the output
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
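+
+// Illustrative dispatch (variable names hypothetical): an element-wise Div
+// between two tensors that already share dimensions lands here and maps to
+// nvinfer1::ElementWiseOperation::kDIV.
+//
+//   std::vector<TRT_TensorOrWeights> outs;
+//   TF_RETURN_IF_ERROR(
+//       BinaryTensorOpTensor(ctx, div_node_def, lhs_tensor, rhs_tensor,
+//                            &outs));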
+
+tensorflow::Status ConvertPlaceholder(
+    Converter& ctx, tensorflow::NodeDef const& node_def,
+    std::vector<TRT_TensorOrWeights> const& inputs,
+    std::vector<TRT_TensorOrWeights>* outputs) {
+  LOG(DEBUG) << "Placeholder should have been replaced already";
+  return tensorflow::errors::Unimplemented("cannot convert Placeholder op");
+  // The unreachable code below is kept for reference: a Placeholder would map
+  // to a network input.
+  TFAttrs attrs(node_def);
+  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
+  nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
+
+  dims.nbDims--;
+  for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
+
+  nvinfer1::ITensor* output =
+      ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
+  if (!output) {
+    return tensorflow::errors::InvalidArgument("Failed to create Input layer");
+  }
+  outputs->push_back(TRT_TensorOrWeights(output));
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertConv2D(Converter& ctx,
+                                 tensorflow::NodeDef const& node_def,
+                                 std::vector<TRT_TensorOrWeights> const& inputs,
+                                 std::vector<TRT_TensorOrWeights>* outputs) {
+  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+  // nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+  // TODO(jie): handle NHWC/NCHW transpose;
+  TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
+  TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
+  reorder_rsck_to_kcrs(weights_rsck, &weights);
+  TRT_ShapedWeights biases(weights.type_);
+  int noutput = weights.shape_.d[0];
+  nvinfer1::DimsHW kernel_size;
+  kernel_size.h() = weights.shape_.d[2];
+  kernel_size.w() = weights.shape_.d[3];
+  TFAttrs attrs(node_def);
+
+  int h_index = 2;
+  int w_index = 3;
+  auto data_format = attrs.get<std::string>("data_format");
+  if (data_format == "NHWC") {
+    tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+                                 {0, 3, 1, 2});
+    h_index = 1;
+    w_index = 2;
+    // TODO(jie): transpose it
+  } else {
+    LOG(DEBUG) << "data format NCHW: no input transpose needed";
+  }
+  // TODO(jie): stride. (NHWC/NCHW)
+  auto tf_stride = attrs.get<std::vector<int>>("strides");
+  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+
+  auto tensor_dim = tensor->getDimensions();
+  std::vector<std::pair<int, int>> padding;
+  // TODO(jie): padding.
+  if (attrs.get<std::string>("padding") == "SAME") {
+    // This is NCHW tensor with no batch dimension.
+    //  1 -> h
+    //  2 -> w
+    padding = createSamePadding(stride, kernel_size,
+                                {static_cast<int>(tensor_dim.d[h_index]),
+                                 static_cast<int>(tensor_dim.d[w_index])});
+  } else {
+    // return tensorflow::errors::Unimplemented(
+    //          "Current Conv2D cannot support padding other than SAME");
+    padding = {{0, 0}, {0, 0}};
+  }
+
+  if (padding[0].first != padding[0].second ||
+      padding[1].first != padding[1].second) {
+    // TODO(jie): handle asymmetric padding
+    // return tensorflow::errors::Unimplemented(
+    //         "Asymmetric padding not implemented yet");
+    auto padLayer = ctx.network()->addPadding(
+        *const_cast<nvinfer1::ITensor*>(tensor),
+        nvinfer1::DimsHW(padding[1].first, padding[0].first),
+        nvinfer1::DimsHW(padding[1].second, padding[0].second));
+    tensor = padLayer->getOutput(0);
+  }
+
+  nvinfer1::IConvolutionLayer* layer =
+      ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
+                                    noutput, kernel_size, weights, biases);
+
+  layer->setStride(stride);
+  layer->setPadding({padding[0].first, padding[1].first});
+  layer->setName(node_def.name().c_str());
+  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+  if (data_format == "NHWC") {
+    // TODO(jie): transpose it back!
+    output_tensor = ctx.transposeTensor(output_tensor, {0, 2, 3, 1});
+  } else {
+    LOG(DEBUG) << "data format NCHW: no output transpose needed";
+  }
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
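+
+// Worked example of the NHWC handling above, assuming a typical slim conv:
+// the input [N, H, W, C] is transposed with {0, 3, 1, 2} to [N, C, H, W]
+// before addConvolution, strides are read from h_index=1 / w_index=2, and
+// the result is transposed back with {0, 2, 3, 1}. For instance,
+//
+//   strides attr [1, 2, 2, 1]  ->  nvinfer1::DimsHW stride(2, 2)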
+
+tensorflow::Status ConvertPool(Converter& ctx,
+                               tensorflow::NodeDef const& node_def,
+                               std::vector<TRT_TensorOrWeights> const& inputs,
+                               std::vector<TRT_TensorOrWeights>* outputs) {
+  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+  TFAttrs attrs(node_def);
+
+  int h_index = 2;
+  int w_index = 3;
+  auto data_format = attrs.get<std::string>("data_format");
+  if (data_format == "NHWC") {
+    h_index = 1;
+    w_index = 2;
+    tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+                                 {0, 3, 1, 2});
+  } else {
+    LOG(DEBUG) << "data format NCHW: no input transpose needed";
+  }
+  nvinfer1::PoolingType type;
+  // TODO(jie): support other pooling type
+  if (node_def.op() == "MaxPool")
+    type = nvinfer1::PoolingType::kMAX;
+  else
+    return tensorflow::errors::Unimplemented("only supports Max pool");
+
+  // TODO(jie): NCHW
+  auto tf_stride = attrs.get<std::vector<int>>("strides");
+  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+
+  auto tf_kernel = attrs.get<std::vector<int>>("ksize");
+  nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
+
+  auto tensor_dim = tensor->getDimensions();
+  std::vector<std::pair<int, int>> padding;
+  // TODO(jie): padding.
+  if (attrs.get<std::string>("padding") == "SAME") {
+    // This is NCHW tensor with no batch dimension.
+    //  1 -> h
+    //  2 -> w
+    padding = createSamePadding(
+        stride, ksize,
+        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
+  } else if (attrs.get<std::string>("padding") == "VALID") {
+    // No padding for valid padding here
+    LOG(DEBUG) << "no padding added for VALID padding in pool "
+               << node_def.name();
+    padding = {{0, 0}, {0, 0}};
+  } else {
+    return tensorflow::errors::Unimplemented(
+        "Current MaxPool cannot support padding other than SAME and VALID");
+  }
+
+  if (padding[0].first != padding[0].second ||
+      padding[1].first != padding[1].second) {
+    // TODO(jie): handle asymmetric padding
+    // return tensorflow::errors::Unimplemented(
+    //          "Asymmetric padding not implemented yet");
+    auto padLayer = ctx.network()->addPadding(
+        *const_cast<nvinfer1::ITensor*>(tensor),
+        nvinfer1::DimsHW(padding[1].first, padding[0].first),
+        nvinfer1::DimsHW(padding[1].second, padding[0].second));
+    tensor = padLayer->getOutput(0);
+  }
+
+  nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
+      *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
+
+  layer->setStride(stride);
+  layer->setPadding({padding[0].first, padding[1].first});
+  layer->setName(node_def.name().c_str());
+  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+  if (data_format == "NHWC") {
+    // TODO(jie): transpose it back!
+    output_tensor = ctx.transposeTensor(output_tensor, {0, 2, 3, 1});
+  } else {
+    LOG(DEBUG) << "data format NCHW: no output transpose needed";
+  }
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
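+
+// Example of the attribute mapping above for an NHWC MaxPool with
+// ksize = [1, 3, 3, 1] and strides = [1, 2, 2, 1]: h_index=1 and w_index=2,
+// so ksize becomes DimsHW(3, 3) and stride becomes DimsHW(2, 2); the tensor
+// is transposed to NCHW first and back to NHWC afterwards.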
+
+tensorflow::Status ConvertActivation(
+    Converter& ctx, tensorflow::NodeDef const& node_def,
+    std::vector<TRT_TensorOrWeights> const& inputs,
+    std::vector<TRT_TensorOrWeights>* outputs) {
+  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+  nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
+      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
+  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertScale(Converter& ctx,
+                                tensorflow::NodeDef const& node_def,
+                                std::vector<TRT_TensorOrWeights> const& inputs,
+                                std::vector<TRT_TensorOrWeights>* outputs) {
+  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+      !inputs.at(1).is_weights())
+    return tensorflow::errors::Unimplemented(
+        "only supports tensor op weight for now, at " + node_def.name());
+  // implement tensor binaryOp weight [channel wise] for now;
+  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+  // nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+
+  // TODO(jie): handle NHWC/NCHW transpose;
+  TRT_ShapedWeights weights = inputs.at(1).weights();
+  // nvinfer1::Weights empty_weights{weights.type, nullptr, 0};
+  TRT_ShapedWeights empty_weights(weights.type_);
+
+  TFAttrs attrs(node_def);
+
+  // transpose NHWC
+  auto data_format = attrs.get<std::string>("data_format");
+  if (data_format == "NHWC") {
+    tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+                                 {0, 3, 1, 2});
+    // TODO(jie): transpose it
+  } else {
+    LOG(DEBUG) << "data format NCHW: no input transpose needed";
+  }
+  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
+      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
+      weights, empty_weights, empty_weights);
+
+  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+  if (data_format == "NHWC") {
+    // TODO(jie): transpose it back!
+    output_tensor = ctx.transposeTensor(output_tensor, {0, 2, 3, 1});
+  } else {
+    LOG(DEBUG) << "data format NCHW: no output transpose needed";
+  }
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertConst(Converter& ctx,
+                                tensorflow::NodeDef const& node_def,
+                                std::vector<TRT_TensorOrWeights> const& inputs,
+                                std::vector<TRT_TensorOrWeights>* outputs) {
+  auto const& weights_tensor = node_def.attr().at("value").tensor();
+
+  // get trt type & shape
+  TFAttrs attrs(node_def);
+  // nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
+  tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");
+
+  // create shaped weights as output
+  tensorflow::Tensor tensor;
+  if (!tensor.FromProto(weights_tensor))
+    return tensorflow::errors::Internal("cannot parse weight tensor proto: " +
+                                        node_def.name());
+
+  TRT_ShapedWeights weights(dtype);
+  if (!weights_tensor.float_val().empty()) {
+    LOG(DEBUG) << "constant stored in float_val: " << node_def.name();
+    nvinfer1::Dims scalar_shape;
+    if (tensor.dims() > 0) {
+      LOG(DEBUG) << "dimensions: " << tensor.dims();
+      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
+                                  get_tensor_shape(tensor));
+    } else {
+      LOG(DEBUG) << "dimensions: " << tensor.dims();
+      scalar_shape.nbDims = 1;
+      scalar_shape.d[0] = 1;
+      scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
+      for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
+        scalar_shape.d[i] = 0;
+        scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
+      }
+      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
+                                  scalar_shape);
+    }
+    // LOG(INFO) << " add: " << weights_tensor.float_val().data();
+    // LOG(INFO) << " value: " << (*weights_tensor.float_val().data());
+
+    // weights = ctx.get_temp_weights(dtype, scalar_shape);
+    // std::memcpy(const_cast<void*>(weights.values),
+    //           weights_tensor.float_val().data(), weights.size_bytes());
+  } else if (!weights_tensor.tensor_content().empty()) {
+    LOG(DEBUG) << "constant stored in tensor_content: " << node_def.name();
+    weights = TRT_ShapedWeights(dtype, weights_tensor.tensor_content().data(),
+                                get_tensor_shape(tensor));
+  } else {
+    return tensorflow::errors::Unimplemented(
+        "unsupported constant type, at " + node_def.name());
+  }
+  // pass the output
+  outputs->push_back(TRT_TensorOrWeights(weights));
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertIdentity(
+    Converter& ctx, tensorflow::NodeDef const& node_def,
+    std::vector<TRT_TensorOrWeights> const& inputs,
+    std::vector<TRT_TensorOrWeights>* outputs) {
+  outputs->push_back(inputs.at(0));
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertBinary(Converter& ctx,
+                                 tensorflow::NodeDef const& node_def,
+                                 std::vector<TRT_TensorOrWeights> const& inputs,
+                                 std::vector<TRT_TensorOrWeights>* outputs) {
+  if (inputs.size() != 2)
+    return tensorflow::errors::FailedPrecondition(
+        "Binary ops require two inputs, at " + node_def.name());
+
+  if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
+    return ConstantFoldBinary(ctx, node_def, inputs, outputs);
+
+  if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
+    return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
+                                inputs.at(1).weights(), outputs);
+
+  if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
+    return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
+                                inputs.at(0).weights(), outputs);
+
+  if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
+    return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
+                                inputs.at(1).tensor(), outputs);
+
+  return tensorflow::errors::Unknown("Binary op input error, at " +
+                                     node_def.name());
+}
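+
+// Dispatch summary for the four operand combinations above:
+//
+//   weights op weights -> ConstantFoldBinary   (host-side fold)
+//   tensor  op weights -> BinaryTensorOpWeight (IScaleLayer)
+//   weights op tensor  -> BinaryTensorOpWeight (operands swapped; note that
+//                         this is only correct for commutative ops)
+//   tensor  op tensor  -> BinaryTensorOpTensor (IElementWiseLayer)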
+
+tensorflow::Status ConvertUnary(Converter& ctx,
+                                tensorflow::NodeDef const& node_def,
+                                std::vector<TRT_TensorOrWeights> const& inputs,
+                                std::vector<TRT_TensorOrWeights>* outputs) {
+  if (inputs.size() != 1)
+    return tensorflow::errors::FailedPrecondition(
+        "Unary ops require a single tensor input, at " + node_def.name());
+
+  if (inputs.at(0).is_weights())
+    return ConstantFoldUnary(ctx, node_def, inputs, outputs);
+  else if (inputs.at(0).is_tensor())
+    return tensorflow::errors::Unimplemented(
+        "Unary op for tensor not supported, at " + node_def.name());
+
+  return tensorflow::errors::Unknown("Unary op input error, at " +
+                                     node_def.name());
+}
+
+tensorflow::Status ConvertReduce(Converter& ctx,
+                                 tensorflow::NodeDef const& node_def,
+                                 std::vector<TRT_TensorOrWeights> const& inputs,
+                                 std::vector<TRT_TensorOrWeights>* outputs) {
+  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+      !inputs.at(1).is_weights())
+    return tensorflow::errors::InvalidArgument(
+        "Input expects tensor and weights, at " + node_def.name());
+
+  // implement tensor binaryOp weight [channel wise] for now;
+  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+  auto dims = tensor->getDimensions();
+  // restore implicit batch dimension
+  int nbDims = dims.nbDims + 1;
+
+  TRT_ShapedWeights index_list = inputs.at(1).weights();
+
+  TFAttrs attrs(node_def);
+  // TODO(jie): handle data type
+  // auto data_type = attrs.get<nvinfer1::DataType>("T");
+  // index type here is done through TF type
+  //   so I can leverage their EnumToDataType for my cast
+  auto index_type = attrs.get<tensorflow::DataType>("Tidx");
+  // auto keep_dims_flag = attrs.get<bool>("keep_dims");
+
+  // Only expect to handle INT32 as attributes for now
+  if (index_type != tensorflow::DataType::DT_INT32)
+    return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
+  // auto pad_data = const_cast<tensorflow::EnumToDataType<padding_type>::Type*>
+  //                  (pads.values);
+  auto index_list_data =
+      static_cast<int*>(const_cast<void*>(index_list.values_));
+  // auto index_list_data =
+  //       const_cast<tensorflow::EnumToDataType<index_type>::Type*>
+  //         (index_list.values);
+
+  // hack warning:
+  //   have to fall back to pool layer since reduce is not in public TRT yet.
+  if (nbDims != 4)
+    return tensorflow::errors::InvalidArgument(
+        "TRT only supports reduce on 4-dimensional tensors, at " +
+        node_def.name());
+  if (index_list.count() > 2)
+    return tensorflow::errors::InvalidArgument(
+        "TRT cannot support reduce on more than 2 dimensions, at " +
+        node_def.name());
+
+  std::set<int> idx_set;
+  // We cannot reduce over the channel dimension; permuted_index marks when
+  // the tensor has to be transposed first.
+  int permuted_index = -1;
+  for (int i = 0; i < index_list.count(); i++) {
+    if (index_list_data[i] == 0)
+      return tensorflow::errors::InvalidArgument(
+          "TRT cannot reduce over the batch dimension, at " + node_def.name());
+    if (index_list_data[i] == 1) permuted_index = 1;
+    idx_set.emplace(index_list_data[i]);
+  }
+
+  std::vector<int> permutation_order(nbDims);
+  nvinfer1::DimsHW pool_kernel;
+  if (permuted_index == 1) {
+    for (int i = 2; i < nbDims; i++) {
+      if (idx_set.count(i)) {
+        permuted_index = i;
+        break;
+      }
+    }
+    for (int i = 0; i < nbDims; i++) permutation_order[i] = i;
+
+    permutation_order[permuted_index] = 1;
+    permutation_order[1] = permuted_index;
+
+    // apply permutation before extracting dimension for pool_kernel
+    tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+                                 permutation_order);
+  }
+
+  // apply permutation before extracting dimension for pool_kernel
+  pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
+  pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;
+
+  nvinfer1::ITensor* output_tensor;
+
+  if (node_def.op() == "Mean") {
+    nvinfer1::IPoolingLayer* layer =
+        ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
+                                  nvinfer1::PoolingType::kAVERAGE, pool_kernel);
+    output_tensor = layer->getOutput(0);
+  } else {
+    return tensorflow::errors::Unimplemented(
+        "Op not supported " + node_def.op() + " , at " + node_def.name());
+  }
+  if (permuted_index != -1) {
+    // undo the transpose that moved the reduced dimension off channel
+    output_tensor = ctx.transposeTensor(
+        const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
+  }
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertPad(Converter& ctx,
+                              tensorflow::NodeDef const& node_def,
+                              std::vector<TRT_TensorOrWeights> const& inputs,
+                              std::vector<TRT_TensorOrWeights>* outputs) {
+  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+      !inputs.at(1).is_weights())
+    return tensorflow::errors::InvalidArgument(
+        "Input expects tensor and weights, at " + node_def.name());
+
+  // implement tensor binaryOp weight [channel wise] for now;
+  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+  auto dims = tensor->getDimensions();
+  // restore implicit batch dimension
+  int nbDims = dims.nbDims + 1;
+
+  TRT_ShapedWeights pads = inputs.at(1).weights();
+
+  TFAttrs attrs(node_def);
+  // padding type here is done through TF type
+  //   so I can leverage their EnumToDataType for my cast
+  auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
+  // TODO(jie): handle data type conversion for TRT?
+  // auto data_type = attrs.get<nvinfer1::DataType>("T");
+
+  if (pads.shape_.d[0] != nbDims || pads.shape_.d[1] != 2)
+    return tensorflow::errors::InvalidArgument(
+        "Pad only supports explicit padding on 4-dimensional tensors, at " +
+        node_def.name());
+
+  // Only expect to handle INT32 as attributes for now
+  if (padding_type != tensorflow::DataType::DT_INT32)
+    return tensorflow::errors::Unimplemented(
+        "Tpaddings supports only DT_INT32");
+  // auto pad_data = const_cast<tensorflow::EnumToDataType<padding_type>::Type*>
+  //                  (pads.values);
+  auto pad_data = static_cast<int*>(const_cast<void*>(pads.values_));
+
+  std::vector<int32_t> pad_index;
+  for (int i = 0; i < nbDims; i++) {
+    if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
+      pad_index.push_back(i);
+  }
+
+  // no padding at all, we should exit
+  if (pad_index.size() == 0) {
+    outputs->push_back(inputs.at(0));
+    return tensorflow::Status::OK();
+  }
+
+  // the TRT padding layer supports padding on at most 2 dimensions (GIE-2579)
+  if (pad_index.size() > 2)
+    return tensorflow::errors::InvalidArgument(
+        "Padding layer does not support padding on more than 2 dimensions");
+
+  // padding on batch dimension is not supported
+  if (pad_index[0] == 0)
+    return tensorflow::errors::InvalidArgument(
+        "Padding layer does not support padding on batch dimension");
+
+  // Not fully general yet: padding both dimension 1 and dimension 3 at once
+  // is rejected below.
+  // TODO(jie): implement pad the way the UFF parser does.
+  if (pad_index.size() == 2 && pad_index[0] == 1 && pad_index[1] == 3)
+    return tensorflow::errors::Unimplemented(
+        "Padding layer does not support padding on dimension 1 and 3 yet");
+
+  bool legit_pad = true;
+  nvinfer1::DimsHW pre_padding(0, 0);
+  nvinfer1::DimsHW post_padding(0, 0);
+
+  std::vector<int32_t> permuted_pad_index(pad_index);
+  if (pad_index[0] == 1) {
+    legit_pad = false;
+    tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+                                 {0, 3, 2, 1});
+    permuted_pad_index[0] = 3;
+  }
+
+  for (size_t i = 0; i < pad_index.size(); i++) {
+    int index = pad_index[i];
+    if (permuted_pad_index[i] == 2) {
+      pre_padding.h() = pad_data[index * 2];
+      post_padding.h() = pad_data[index * 2 + 1];
+    } else if (permuted_pad_index[i] == 3) {
+      pre_padding.w() = pad_data[index * 2];
+      post_padding.w() = pad_data[index * 2 + 1];
+    }
+  }
+
+  nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
+      *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
+  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+  if (!legit_pad)
+    output_tensor = ctx.transposeTensor(
+        const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});
+
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
+
+void Converter::register_op_converters() {
+  // vgg_16 slim implementation
+  _op_registry["Placeholder"] = ConvertPlaceholder;
+  _op_registry["Conv2D"] = ConvertConv2D;
+  _op_registry["Relu"] = ConvertActivation;
+  _op_registry["MaxPool"] = ConvertPool;
+  // This could be really handled as ConvertBinary
+  _op_registry["BiasAdd"] = ConvertScale;
+  _op_registry["Const"] = ConvertConst;
+  // _op_registry["MatMul"] = ConvertFullyConnected; // not used in vgg
+  // TODO(ben,jie): this is a temp hack.
+  _op_registry["Identity"] = ConvertIdentity;  // Identity should be removed
+  // _op_registry["AvgPool"] = ConvertPool;
+
+  // resnet_50_v1 slim implementation
+  _op_registry["Add"] = ConvertBinary;
+  _op_registry["Mul"] = ConvertBinary;
+  _op_registry["Sub"] = ConvertBinary;
+  _op_registry["Rsqrt"] = ConvertUnary;
+  _op_registry["Mean"] = ConvertReduce;
+  _op_registry["Pad"] = ConvertPad;
+  // TODO(ben,jie): Add more ops
+}
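+
+// Extending the registry is a one-line change per op; e.g. a hypothetical
+// average-pool hookup once ConvertPool accepts it:
+//
+//   _op_registry["AvgPool"] = ConvertPool;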
+
+}  // namespace
+
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
+    const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
+    const std::vector<std::pair<int, int>>& input_inds,
+    const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
+    size_t max_workspace_size, const ShapeMap& shape_map,
+    tensorflow::NodeDef* trt_node) {
+  // Visit nodes in reverse topological order and construct the TRT network.
+
+  // Toposort
+  std::vector<tensorflow::Node*> order_vec;
+  tensorflow::GetPostOrder(graph, &order_vec);
+  // Select just the subgraph
+  std::list<tensorflow::Node*> order;
+  for (tensorflow::Node* node : order_vec) {
+    if (subgraph_node_ids.count(node->id())) {
+      // order.push_back(node);
+      order.push_front(node);  // we want topological order to construct the
+                               // network layer by layer
+    }
+  }
+  // topological order is needed to build TRT network
+  LOG(DEBUG) << "BUILDING 1: subgraph node order selected";
+
+  //  nvinfer1::ILogger::Severity verbosity =
+  //  nvinfer1::ILogger::Severity::kWARNING;
+  tensorflow::tensorrt::Logger trt_logger;
+  //  TRT_Logger trt_logger(verbosity);
+
+  LOG(DEBUG) << "BUILDING 2: TRT logger created";
+
+  auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
+  if (!trt_builder) {
+    return tensorflow::errors::Internal(
+        "failed to create TensorRT builder object");
+  }
+
+  LOG(DEBUG) << "BUILDING 3: TRT builder created";
+
+  auto trt_network = infer_object(trt_builder->createNetwork());
+  if (!trt_network) {
+    return tensorflow::errors::Internal(
+        "failed to create TensorRT network object");
+  }
+
+  LOG(DEBUG) << "BUILDING 4: TRT network created";
+
+  // Build the network
+  Converter converter(trt_network.get());
+
+  LOG(DEBUG) << "BUILDING 5: converter ready, adding inputs";
+  std::vector<std::string> input_names;
+  std::vector<tensorflow::DataType> input_dtypes;
+  for (std::pair<int, int> const& input : input_inds) {
+    LOG(DEBUG) << "parsing input";
+    int node_id = input.first;
+    int output_idx = input.second;
+    tensorflow::Node* node = graph.FindNodeId(node_id);
+    auto node_name = node->name();
+    input_names.push_back(node_name);  // insert original node name without port
+    // TODO(jie): alternative :)
+    // tensorflow::DataType tf_dtype = node->output_type(output_idx);
+    if (shape_map.count(node_name) == 0)
+      return tensorflow::errors::Internal("failed to find input node: " +
+                                          node_name);
+
+    auto input_entry_vec = shape_map.at(node_name);
+    if (static_cast<int>(input_entry_vec.size()) <= output_idx)
+      return tensorflow::errors::Internal(
+          "accessing output index of: " + std::to_string(output_idx) +
+          ", at node: " + node_name + " with output entry from shape_map: " +
+          std::to_string(input_entry_vec.size()));
+
+    auto input_entry = input_entry_vec.at(output_idx);
+
+    tensorflow::DataType tf_dtype = input_entry.second;
+    input_dtypes.push_back(tf_dtype);
+
+    nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
+    TF_CHECK_OK(convert_dtype(tf_dtype, &dtype));
+
+    LOG(DEBUG) << "accessing output index of: " << std::to_string(output_idx)
+               << ", at node: " << node_name
+               << " with output entry from shape_map: "
+               << std::to_string(input_entry_vec.size());
+    // TODO(ben,jie): update TRT input format/dimension
+    nvinfer1::DimsCHW input_dim_pseudo_chw;
+    for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;
+
+    for (int i = 1; i < input_entry.first.dims(); i++) {
+      LOG(DEBUG) << "dimension: " << i
+                 << " , size: " << input_entry.first.dim_size(i);
+      input_dim_pseudo_chw.d[i - 1] = input_entry.first.dim_size(i);
+    }
+
+    // TODO(ben,jie): proper way to restore input tensor name?
+    auto input_tensor_name = node_name;
+    if (output_idx != 0)
+      input_tensor_name = node_name + ":" + std::to_string(output_idx);
+
+    nvinfer1::ITensor* input_tensor = converter.network()->addInput(
+        input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
+
+    if (!input_tensor)
+      return tensorflow::errors::InvalidArgument(
+          "Failed to create Input layer");
+    LOG(DEBUG) << "input tensor name: " << input_tensor_name;
+
+    if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
+      return tensorflow::errors::AlreadyExists(
+          "output tensor already exists for op: " + input_tensor_name);
+  }
+
+  LOG(DEBUG) << "finished adding inputs";
+
+  for (const tensorflow::Node* node : order) {
+    tensorflow::NodeDef const& node_def = node->def();
+    LOG(DEBUG) << "converting node: " << node_def.name() << " , "
+               << node_def.op();
+    TF_RETURN_IF_ERROR(converter.convert_node(node_def));
+  }
+
+  LOG(DEBUG) << "finished conversion";
+
+  // Gather output metadata
+  std::vector<std::string> output_names;
+  std::vector<tensorflow::DataType> output_dtypes;
+  for (std::pair<int, int> const& output : output_inds) {
+    int node_id = output.first;
+    int output_idx = output.second;
+    tensorflow::Node* node = graph.FindNodeId(node_id);
+    std::string op_name = node->name();
+    std::string tensor_name = op_name;
+    if (output_idx != 0)
+      tensor_name = tensor_name + ":" + std::to_string(output_idx);
+    LOG(DEBUG) << "output tensor name: " << tensor_name;
+    output_names.push_back(tensor_name);
+    auto tensor_or_weights = converter.get_tensor(tensor_name);
+    if (!tensor_or_weights.is_tensor()) {
+      return tensorflow::errors::InvalidArgument(
+          "Output node is weights not tensor");
+    }
+    nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
+    if (!tensor) {
+      return tensorflow::errors::NotFound("Output tensor not found: " +
+                                          tensor_name);
+    }
+    converter.network()->markOutput(*tensor);
+    tensorflow::DataType tf_dtype = node->output_type(output_idx);
+    output_dtypes.push_back(tf_dtype);
+    nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
+    TF_RETURN_IF_ERROR(convert_dtype(tf_dtype, &trt_dtype));
+    tensor->setType(trt_dtype);
+  }
+
+  LOG(DEBUG) << "finished output";
+
+  // Build the engine
+  trt_builder->setMaxBatchSize(max_batch_size);
+  trt_builder->setMaxWorkspaceSize(max_workspace_size);
+  LOG(INFO) << "starting build engine";
+  // TODO(ben,jie): half2 and int8 mode support
+  std::string engine_plan_string;
+  {
+    auto trt_engine =
+        infer_object(trt_builder->buildCudaEngine(*converter.network()));
+    LOG(INFO) << "built network";
+    auto engine_plan = infer_object(trt_engine->serialize());
+    LOG(INFO) << "serialized engine";
+    const char* engine_plan_data =
+        static_cast<const char*>(engine_plan->data());
+    engine_plan_string =
+        std::string(engine_plan_data, engine_plan_data + engine_plan->size());
+  }
+  // std::ofstream engine_out("mini.engine");
+  // engine_out << engine_plan_string;
+  // engine_out.close();
+
+  LOG(INFO) << "finished engine";
+
+  // Build the TRT op
+  // TODO(sami,ben,jie): proper naming!
+  static int static_id = 0;
+  tensorflow::NodeDefBuilder op_builder(
+      "my_trt_op" + std::to_string(static_id++), "TRTEngineOp");
+  std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    int output_idx = input_inds.at(i).second;
+    // we wired up the input here already; it is redundant to do it again in
+    //  ConvertSubGraphToTensorRT (convert_graph.cc)
+    auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(input_names.at(i),
+                           output_idx, input_dtypes.at(i));
+    income_edges.push_back(incoming_edge);
+  }
+  tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut>
+    input_list(income_edges);
+  op_builder.Input(input_list);
+
+  LOG(INFO) << "finished op preparation";
+
+  auto status = op_builder.Attr("serialized_engine", engine_plan_string)
+                    .Attr("input_nodes", input_names)
+                    .Attr("output_nodes", output_names)
+                    .Attr("OutT", output_dtypes)
+                    .Finalize(trt_node);
+
+  LOG(INFO) << status.ToString();
+  LOG(INFO) << "finished op building";
+
+  return tensorflow::Status::OK();
+}
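+
+// Caller sketch (illustrative; the actual wiring lives in convert_graph.cc,
+// and the sizes below are arbitrary): run shape inference first, then hand
+// the subgraph plus shapes to this function and splice the returned NodeDef
+// back into the graph.
+//
+//   ShapeMap shapes;
+//   TF_RETURN_IF_ERROR(
+//       tensorflow::trt::inferShapes(graph_def, output_names, shapes));
+//   tensorflow::NodeDef trt_node;
+//   TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
+//       graph, subgraph_node_ids, input_inds, output_inds,
+//       /*max_batch_size=*/8, /*max_workspace_size=*/1 << 30, shapes,
+//       &trt_node));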
+
+}  // namespace convert
+}  // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
new file mode 100644
index 0000000..a624582
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
+
+#include <set>
+#include <vector>
+#include <utility>
+
+#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorrt {
+namespace convert {
+
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
+    const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
+    const std::vector<std::pair<int, int>>&
+        input_inds,  // {node_id, output_idx}
+    const std::vector<std::pair<int, int>>&
+        output_inds,  // {node_id, output_idx}
+    size_t max_batch_size, size_t max_workspace_size, const ShapeMap& shape_map,
+    tensorflow::NodeDef* trt_node);
+}  // namespace convert
+}  // namespace tensorrt
+
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
diff --git a/tensorflow/contrib/tensorrt/convert/inferShapes.cc b/tensorflow/contrib/tensorrt/convert/inferShapes.cc
new file mode 100644
index 0000000..c7f0f00
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/inferShapes.cc
@@ -0,0 +1,125 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
+#include <functional>
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb_text.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
+
+namespace tensorflow {
+namespace trt {
+std::vector<tensorflow::DataType> getTypes(const tensorflow::OpDef& op,
+                                           const tensorflow::NodeDef& nd,
+                                           bool inp = true) {
+  const auto& attrMap = nd.attr();
+  auto getType = [&attrMap](decltype(
+                     op.input_arg(0)) a) -> std::vector<tensorflow::DataType> {
+    std::vector<tensorflow::DataType> tvec;
+    if (!a.type_list_attr().empty()) {  // get the list types
+      const auto& tl = attrMap.at(a.type_list_attr()).list();
+      int tsize = tl.type_size();
+      tvec.reserve(tsize);
+      for (int t = 0; t < tsize; t++) {
+        tvec.push_back(tl.type(t));
+      }
+      return tvec;
+    }
+    tensorflow::DataType cType = tensorflow::DT_INVALID;
+    if (a.type() != tensorflow::DT_INVALID) {  // get defined types
+      cType = a.type();
+    } else if (!a.type_attr().empty()) {
+      cType = attrMap.at(a.type_attr()).type();
+    }
+    if (!a.number_attr().empty()) {  // numbertypes
+      int64 nTensors = attrMap.at(a.number_attr()).i();
+      tvec = std::vector<tensorflow::DataType>(nTensors, cType);
+      return tvec;
+    }
+    tvec.push_back(cType);
+    return tvec;
+  };
+  std::vector<tensorflow::DataType> types;
+  if (inp) {
+    int n_inputs = op.input_arg_size();
+    for (int i = 0; i < n_inputs; i++) {
+      auto tout = getType(op.input_arg(i));
+      LOG(DEBUG) << "Node= " << nd.name() << " #inputs: " << tout.size();
+      types.insert(types.end(), tout.begin(), tout.end());
+    }
+  } else {
+    int n_outputs = op.output_arg_size();
+    // types.resize(n_outputs);
+    for (int i = 0; i < n_outputs; i++) {
+      auto tout = getType(op.output_arg(i));
+      LOG(DEBUG) << "Node= " << nd.name() << " #outputs: " << tout.size();
+      types.insert(types.end(), tout.begin(), tout.end());
+    }
+  }
+  return types;
+}
+
+tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
+                               const std::vector<std::string>& output_names,
+                               ShapeMap& shapes) {
+  tensorflow::Graph g(OpRegistry::Global());
+  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+      tensorflow::GraphConstructorOptions(), graph_def, &g));
+  std::vector<tensorflow::Node*> POnodes;
+  tensorflow::GetPostOrder(g, &POnodes);
+  tensorflow::ShapeRefiner refiner(graph_def.versions().producer(),
+                                   OpRegistry::Global());
+  for (auto n = POnodes.rbegin(); n != POnodes.rend(); ++n) {
+    TF_CHECK_OK(refiner.AddNode(*n));
+  }
+
+  auto shape2PTS = [](tensorflow::shape_inference::InferenceContext* ic,
+                      const tensorflow::shape_inference::ShapeHandle& sh)
+      -> tensorflow::PartialTensorShape {
+    std::vector<int64> dims;
+    int64 rank = ic->Rank(sh);
+    for (int64 i = 0; i < rank; i++) {
+      auto dh = ic->Dim(sh, i);
+      dims.push_back(ic->Value(dh));
+    }
+    return tensorflow::PartialTensorShape(dims);
+  };
+  for (const auto& n : POnodes) {
+    auto ic = refiner.GetContext(n);
+    if (ic) {
+      int nOuts = ic->num_outputs();
+      auto types = getTypes(n->op_def(), n->def(), false);
+      std::vector<
+          std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>
+          SAT;
+      for (int i = 0; i < nOuts; i++) {
+        auto PTS = shape2PTS(ic, ic->output(i));
+        SAT.push_back({PTS, types.at(i)});
+      }
+      shapes[n->name()] = SAT;
+    } else {
+      LOG(WARNING) << "Node " << n->name() << " doesn't have InferenceContext!";
+    }
+  }
+  return tensorflow::Status::OK();
+}
+}  // namespace trt
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/inferShapes.h b/tensorflow/contrib/tensorrt/convert/inferShapes.h
new file mode 100644
index 0000000..b94f1ee
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/inferShapes.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include <utility>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+
+typedef std::unordered_map<std::string,
+                           std::vector<std::pair<tensorflow::PartialTensorShape,
+                                                 tensorflow::DataType>>>
+    ShapeMap;
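+
+// A ShapeMap entry is keyed by node name and holds one (shape, dtype) pair
+// per output; for example (values hypothetical):
+//
+//   shapes["conv1/Conv2D"] =
+//       {{tensorflow::PartialTensorShape({8, 224, 224, 64}),
+//         tensorflow::DT_FLOAT}};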
+namespace tensorflow {
+namespace trt {
+tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
+                               const std::vector<std::string>& output_names,
+                               ShapeMap& shapes);
+}  // namespace trt
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
new file mode 100644
index 0000000..a1524a5
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
+#include <cuda_runtime_api.h>
+#include <sstream>
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
+// Use TF logging for TensorRT information
+
+
+namespace tensorflow {
+static ::tensorflow::tensorrt::Logger gLogger;
+
+using namespace nvinfer1;
+
+namespace tensorrt {
+
+TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
+  // char *gieModelStream{nullptr};
+  // size_t size{0};
+
+  // read serialized_engine
+  std::string serialized_engine;
+  OP_REQUIRES_OK(context,
+                 context->GetAttr("serialized_engine", &serialized_engine));
+
+  // register input output node name in trt_sub_graph
+  OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
+  OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
+
+  // TODO(samikama) runtime should be taken from a resourcemanager as well.
+  //  Only engine should be in the op and context and runtime should be taken
+  //  from resourcemanager
+  IRuntime* infer = createInferRuntime(gLogger);
+  trt_engine_ptr_.reset(infer->deserializeCudaEngine(
+      serialized_engine.c_str(), serialized_engine.size(), nullptr));
+
+  trt_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
+  // runtime is safe to delete after engine creation
+  infer->destroy();
+  std::stringstream oss;
+  // debug iterate through all binding instances
+  for (int i = 0; i < trt_engine_ptr_->getNbBindings(); i++) {
+    LOG(INFO) << "index: " << i
+              << ", binding name: " << trt_engine_ptr_->getBindingName(i);
+
+    if (trt_engine_ptr_->bindingIsInput(i)) {
+      LOG(INFO) << "INPUT";
+    } else {
+      LOG(INFO) << "OUTPUT";
+    }
+    oss << "Dimension: ";
+    auto dims = trt_engine_ptr_->getBindingDimensions(i);
+    oss << " nbDims: " << dims.nbDims << " -> ";
+    for (int j = 0; j < Dims::MAX_DIMS; j++) {
+      oss << dims.d[j] << ", ";
+    }
+    LOG(INFO) << oss.str();
+    oss.str("");
+    switch (trt_engine_ptr_->getBindingDataType(i)) {
+      case nvinfer1::DataType::kFLOAT:
+        LOG(INFO) << "data type float" << std::endl;
+        break;
+      case nvinfer1::DataType::kHALF:
+        LOG(INFO) << "data type half" << std::endl;
+        break;
+      case nvinfer1::DataType::kINT8:
+        LOG(INFO) << "data type int8" << std::endl;
+        break;
+    }
+  }
+
+  // CHECK_NE(cudaStreamCreate(&stream_),0); // logic here is wrong
+  // cudaStreamCreate(&stream_);
+}
+
+void TRTEngineOp::Compute(OpKernelContext* context) {
+  int nbBindings = context->num_inputs() + context->num_outputs();
+  // TODO(jjsjann123) multiple input/output
+  std::vector<void*> buffers(nbBindings);
+
+  size_t bindingIndex;
+  int nbBatch = 0;
+  bool valid = true;
+  for (int i = 0; i < context->num_inputs(); i++) {
+    // Grab the input tensor
+    bindingIndex = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
+
+    const Tensor& input_tensor = context->input(i);
+    const TensorShape& input_shape = input_tensor.shape();
+    if (i == 0) {
+      nbBatch = input_shape.dim_size(0);
+    } else if (nbBatch != input_shape.dim_size(0)) {
+      valid = false;
+      break;
+    }
+    // int64 input_shape.dim_size(int d)
+    // int input_shape.dims()
+    switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) {
+      case nvinfer1::DataType::kFLOAT:
+        LOG(INFO) << "float";
+        buffers[bindingIndex] = (void*)(input_tensor.flat<float>().data());
+        break;
+      case nvinfer1::DataType::kHALF:
+        LOG(INFO) << "half";
+        // buffers[bindingIndex] = (void*)input_tensor.flat<float16>().data();
+        break;
+      case nvinfer1::DataType::kINT8:
+        LOG(INFO) << "int8";
+        // buffers[bindingIndex] = (void*)input_tensor.flat<int8>().data();
+        break;
+    }
+  }
+
+  if (!valid) LOG(WARNING) << "input data inconsistent batch size";
+
+  for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
+    // It is unfortunate that we have to reallocate the output buffer every run.
+    // Create an output tensor
+    bindingIndex = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
+    Tensor* output_tensor = NULL;
+
+    TensorShape output_shape;
+    if (bindingIndex != -1) {
+      LOG(INFO) << "got binding " << bindingIndex;
+      auto dims = trt_engine_ptr_->getBindingDimensions(bindingIndex);
+      std::vector<int> trt_shape(dims.nbDims + 1);
+      trt_shape[0] = nbBatch;
+      for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
+      TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
+                                  &output_shape);
+    } else {
+      LOG(INFO) << "no binding ";
+      break;
+    }
+
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(i, output_shape, &output_tensor));
+    // buffers[bindingIndex] = (void*)output_tensor->flat<float>();
+    // buffers[bindingIndex] = output_tensor->flat<float>().data();
+    switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) {
+      case nvinfer1::DataType::kFLOAT:
+        LOG(INFO) << "float";
+        buffers[bindingIndex] =
+            reinterpret_cast<void*>(output_tensor->flat<float>().data());
+        break;
+      case nvinfer1::DataType::kHALF:
+        LOG(INFO) << "half";
+        // buffers[bindingIndex] = (void*)output_tensor->flat<float16>().data();
+        break;
+      case nvinfer1::DataType::kINT8:
+        LOG(INFO) << "int8";
+        // buffers[bindingIndex] = (void*)output_tensor->flat<int8>().data();
+        break;
+    }
+  }
+  // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
+  const cudaStream_t* stream = CHECK_NOTNULL(
+      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+                                                ->stream()
+                                                ->implementation()
+                                                ->CudaStreamMemberHack()));
+
+  trt_context_ptr_->enqueue(nbBatch, &buffers[0], *stream, nullptr);
+  cudaStreamSynchronize(*stream);
+}
+
+REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
+}  // namespace tensorrt
+}  // namespace tensorflow
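+
+// For reference, the NodeDef produced by ConvertSubGraphToTensorRTNodeDef
+// carries four attributes (sketched below; values illustrative). This kernel
+// reads the first three, while "OutT" describes the output dtypes.
+//
+//   op: "TRTEngineOp"
+//   attr { key: "serialized_engine" ... }  // serialized TRT plan bytes
+//   attr { key: "input_nodes"       ... }  // names used with getBindingIndex
+//   attr { key: "output_nodes"      ... }
+//   attr { key: "OutT"              ... }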
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
new file mode 100644
index 0000000..631fc11
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
+
+#include <NvInfer.h>
+#include <cuda_runtime_api.h>
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+namespace tensorrt {
+class Logger;
+class TRTEngineOp : public OpKernel {
+ public:
+  explicit TRTEngineOp(OpKernelConstruction* context);
+
+  void Compute(OpKernelContext* context) override;
+
+ private:
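+  // TensorRT objects must be released with destroy() rather than delete, so
+  // hold them in a unique_ptr with a custom deleter.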
+  template <typename T>
+  struct Destroyer {
+    void operator()(T* d) { d->destroy(); }
+  };
+  template <typename T>
+  using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
+  destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
+  // TODO(samikama) context should go to a resource manager!
+  destroyed_ptr<nvinfer1::IExecutionContext> trt_context_ptr_;
+  std::vector<string> input_nodes_;
+  std::vector<string> output_nodes_;
+};
+
+}  // namespace tensorrt
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/contrib/tensorrt/log/trt_logger.cc
new file mode 100644
index 0000000..545a4aa
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.cc
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+// Use TF logging for TensorRT information.
+#include "tensorflow/core/platform/logging.h"
+
+#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
+//------------------------------------------------------------------------------
+namespace tensorflow {
+
+//------------------------------------------------------------------------------
+namespace tensorrt {
+
+void Logger::log(Severity severity, const char* msg) {
+  // suppress info-level messages
+  switch (severity) {
+    case Severity::kINFO: {  // mark TRT info messages as debug!
+      LOG(DEBUG) << msg;
+      break;
+    }
+    case Severity::kWARNING: {
+      LOG(WARNING) << msg;
+      break;
+    }
+    case Severity::kERROR: {
+      LOG(ERROR) << msg;
+      break;
+    }
+    case Severity::kINTERNAL_ERROR: {
+      LOG(FATAL) << msg;
+      break;
+    }
+    // Not reachable with the current enum values, but a default case guards
+    // against future additions to the Severity enum.
+    default: {
+      LOG(FATAL) << name_ << ": got unknown severity level from TensorRT: "
+                 << msg;
+      break;
+    }
+  }
+}
+
+}  // namespace tensorrt
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h
new file mode 100644
index 0000000..10a78b7
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.h
@@ -0,0 +1,41 @@
+// -*- c++ -*-
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
+
+// Use TF logging for TensorRT information.
+#include <NvInfer.h>
+#include <string>
+
+//------------------------------------------------------------------------------
+namespace tensorflow {
+
+//------------------------------------------------------------------------------
+namespace tensorrt {
+
+// Logger for TensorRT info/warning/error messages.
+class Logger : public nvinfer1::ILogger {
+  void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
+
+ private:
+  std::string name_;
+};
+
+}  // namespace tensorrt
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
new file mode 100644
index 0000000..38d3707
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+namespace shape_inference {
+extern Status TRTEngineOpShapeInference(InferenceContext* c);
+}
+
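+// TRTEngineOp runs a pre-built, serialized TensorRT engine on its inputs.
+//   serialized_engine: the serialized TensorRT engine to execute.
+//   input_nodes / output_nodes: engine binding names for inputs and outputs.
+//   InT / OutT: data types of the input and output tensor lists.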
+REGISTER_OP("TRTEngineOp")
+    .Attr("serialized_engine: string")
+    .Attr("input_nodes: list(string)")
+    .Attr("output_nodes: list(string)")
+    .Attr("InT: list({int8, float16, float32})")
+    .Attr("OutT: list({int8, float16, float32})")
+    .Input("in_tensor: InT")
+    .Output("out_tensor: OutT")
+    .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
new file mode 100644
index 0000000..4aeea48
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -0,0 +1,8 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+from tensorflow.contrib.tensorrt.python.trt_convert import CreateInferenceGraph
+# pylint: enable=unused-import,wildcard-import
diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
new file mode 100644
index 0000000..ce78d32
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
@@ -0,0 +1,35 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import platform
+
+if platform.system() != "Windows":
+  # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
+  from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import *
+
+  from tensorflow.contrib.util import loader
+  from tensorflow.python.platform import resource_loader
+  # pylint: enable=wildcard-import,unused-import,g-import-not-at-top
+
+  _trt_engine_op = loader.load_op_library(
+      resource_loader.get_path_to_datafile("_trt_engine_op.so"))
+else:
+  raise RuntimeError("Windows platforms are not supported")
+
+
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
new file mode 100644
index 0000000..a66afa8
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -0,0 +1,91 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the Python wrapper conversion to trt_graph."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import, line-too-long
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
+from tensorflow.python.util import compat
+import tensorflow as tf
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+
+
+def CreateInferenceGraph(input_graph_def, outputs, max_batch_size=1,
+                         max_workspace_size=2 << 20):
+  """Python wrapper for the TRT transformation.
+
+  Args:
+    input_graph_def: GraphDef object containing a model to be transformed.
+    outputs: List of node names for the model outputs.
+    max_batch_size: maximum batch size the TensorRT engines are built for.
+    max_workspace_size: maximum scratch workspace TensorRT may allocate, in
+      bytes.
+
+  Returns:
+    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
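+
+  Example (illustrative; the graph, node name, and sizes are placeholders):
+
+    trt_graph_def = CreateInferenceGraph(
+        input_graph_def=frozen_graph_def,
+        outputs=["logits"],
+        max_batch_size=8,
+        max_workspace_size=1 << 25)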
+  """
+
+  # with errors.raise_exception_on_not_ok_status() as status:
+  #   output_graph_def_string = trt_convert(
+  #       input_graph_def_string,outputs,
+  #       max_batch_size,max_workspace_size, status)
+  g = tf.Graph()
+  with g.as_default():
+    tf.import_graph_def(input_graph_def, name="")
+  rewriter_config = rewriter_config_pb2.RewriterConfig()
+  rewriter_config.optimizers.append('layout')
+  rewriter_config.optimizers.append('constfold')
+
+  # Mark output nodes as fetches so the optimizers keep them.
+  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+  for node_name in outputs:
+    out_node = g.get_operation_by_name(node_name)
+    for out_tensor in out_node.outputs:
+      train_op.append(out_tensor)
+
+  # constant folding
+  mg = meta_graph.create_meta_graph_def(graph=g)
+  meta_graph.add_collection_def(mg, ops.GraphKeys.TRAIN_OP)
+  optimized_graph_def_str = tf_optimizer.OptimizeGraph(
+      rewriter_config, mg).SerializeToString()
+
+  # TODO(sami): Fix this once the C++ library can return a Status.
+  # The TF internal library setup does not let us return a status object from
+  # C++, so trt_convert returns a pair of strings: the first is an encoded
+  # status and the second is the serialized protobuf of the transformed graph.
+  status, output_graph_def_string = trt_convert(
+      optimized_graph_def_str, outputs, max_batch_size, max_workspace_size)
+  del optimized_graph_def_str  # save some memory
+  if len(status) < 2:
+    raise _impl.UnknownError(None, None, status)
+  if status[:2] != "OK":
+    msg = status.split(";")
+    if len(msg) == 1:
+      raise RuntimeError("Status message is malformed: {}".format(status))
+    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
+                                         int(msg[0]))
+  output_graph_def = graph_pb2.GraphDef()
+  output_graph_def.ParseFromString(output_graph_def_string)
+  del output_graph_def_string  # save some memory
+  return output_graph_def
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
new file mode 100644
index 0000000..41da528
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -0,0 +1,259 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/tensorrt/segment/union_find.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+//------------------------------------------------------------------------------
+namespace tensorrt {
+namespace segment {
+
+//------------------------------------------------------------------------------
+namespace {
+
+//------------------------------------------------------------------------------
+bool CanContractEdge(const tensorflow::Edge* edge,
+                     const tensorflow::Graph& graph) {
+  const tensorflow::Node* src = edge->src();
+  const tensorflow::Node* dst = edge->dst();
+
+  // Can't contract edge if doing so would cause a cycle in the
+  // graph. So, if there is a directed path from 'src' to 'dst', other
+  // than 'edge' (or any other direct edge from 'src' to 'dst'), then
+  // combining 'src' and 'dst' will cause a cycle along that path.
+  //
+  // In practice, to avoid modifying the graph and to take advantage
+  // of existing graph functions, we perform an equivalent.
+  //   1. Get all nodes incoming to 'dst', excluding 'src'
+  //   2. Reverse DFS from those nodes
+  //   3. If reverse DFS reaches 'src' then we have a cycle
+  std::vector<tensorflow::Node*> dfs_start_nodes;
+  for (tensorflow::Node* node : dst->in_nodes()) {
+    if (node != src) {
+      dfs_start_nodes.push_back(node);
+    }
+  }
+
+  bool is_cycle = false;
+  if (!dfs_start_nodes.empty()) {
+    tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {},
+                               [&is_cycle, src](tensorflow::Node* node) {
+                                 if (node == src) {
+                                   is_cycle = true;
+                                 }
+                               });
+  }
+
+  return !is_cycle;
+}
+
+//------------------------------------------------------------------------------
+void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
+                  std::vector<const tensorflow::Edge*>* remove_edges) {
+  // Transfer all inputs and outputs of 'dst' to 'src' except edges
+  // connecting the two.
+  tensorflow::Node* src = edge->src();
+  tensorflow::Node* dst = edge->dst();
+
+  // We can use '0' for input/output index because we don't need them
+  // to be accurate for the way we are using the graph.
+  std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
+                                                dst->in_edges().end());
+  for (const tensorflow::Edge* in_edge : in_edges) {
+    if (in_edge->src() != src) {
+      tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+      if (e->src() == graph->source_node()) {
+        graph->AddEdge(e->src(), e->src_output(), src,
+                       tensorflow::Graph::kControlSlot);
+      } else {
+        graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
+      }
+    }
+  }
+
+  std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
+                                                 dst->out_edges().end());
+  for (const tensorflow::Edge* out_edge : out_edges) {
+    tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+    if (e->dst() == graph->sink_node()) {
+      graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
+                     e->dst_input());
+    } else {
+      graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
+    }
+  }
+
+  // Return the edges that must be removed to disconnect 'dst' from
+  // the graph. We don't actually remove 'dst' since the caller holds
+  // references to all the nodes.
+  for (const auto& in_edge : dst->in_edges()) {
+    remove_edges->push_back(in_edge);
+  }
+  for (const auto& out_edge : dst->out_edges()) {
+    remove_edges->push_back(out_edge);
+  }
+}
+
+}  // namespace
+
+//------------------------------------------------------------------------------
+tensorflow::Status SegmentGraph(
+    const tensorflow::GraphDef& gdef,
+    const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
+    const SegmentOptions& options, SegmentNodesVector* segments) {
+  // Create a Graph representation of the GraphDef.
+  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
+                                             gdef.library());
+  tensorflow::Graph graph(flib);
+  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+      tensorflow::GraphConstructorOptions(), gdef, &graph));
+
+  // tensorflow::DumpGraph("Pre-Segment", &graph);
+
+  // Use a union-find to collect the nodes that belong to the same
+  // segment. A node value of nullptr indicates that the node is not a
+  // candidate for TRT.
+  std::vector<UnionFind<tensorflow::Node*>> node_segments;
+  for (int i = 0; i < graph.num_node_ids(); ++i) {
+    tensorflow::Node* node = graph.FindNodeId(i);
+    if (!candidate_fn(node->def())) {
+      node = nullptr;
+    }
+    node_segments.emplace_back(node);
+  }
+
+  // Visit nodes in reverse topological order and use edge
+  // contraction to merge candidate nodes.
+  std::vector<tensorflow::Node*> order;
+  tensorflow::GetPostOrder(graph, &order);
+
+  for (const tensorflow::Node* node : order) {
+    // All output nodes of 'node' have been visited...
+    VLOG(2) << "Trying node " << node->name();
+
+    // 'node' must be a TRT candidate...
+    if (node_segments[node->id()].Value() == nullptr) {
+      VLOG(2) << "... not a TRT candidate";
+      continue;
+    }
+
+    // Contract output edges to combine 'node' with output
+    // nodes. Iterate since combining two nodes may unblock other
+    // combining.
+    while (true) {
+      std::set<const tensorflow::Edge*> contract_edges;
+      for (const tensorflow::Edge* out_edge : node->out_edges()) {
+        VLOG(2) << "... out node " << out_edge->dst()->name();
+
+        // Out node must be TRT candidate...
+        if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
+          VLOG(2) << "... ... not a TRT candidate";
+          continue;
+        }
+
+        if (CanContractEdge(out_edge, graph)) {
+          VLOG(2) << "... ... can contract";
+          contract_edges.insert(out_edge);
+        } else {
+          VLOG(2) << "... ... cannot contract, would form cycle";
+        }
+      }
+
+      if (contract_edges.empty()) {
+        break;
+      }
+
+      // Contract edges and collect the adjacent nodes into the same
+      // segment/subgraph.
+      while (!contract_edges.empty()) {
+        const tensorflow::Edge* contract_edge = *contract_edges.begin();
+        const tensorflow::Node* src = contract_edge->src();
+        const tensorflow::Node* dst = contract_edge->dst();
+
+        VLOG(2) << "Merge " << src->name() << " <- " << dst->name();
+        node_segments[src->id()].Merge(&node_segments[dst->id()]);
+
+        // Contracting the edge leaves disconnected graph edges.
+        // Remove these from the graph and from 'contract_edges' so we
+        // don't visit them again.
+        tensorflow::Edge* e = const_cast<tensorflow::Edge*>(contract_edge);
+        std::vector<const tensorflow::Edge*> remove_edges;
+        ContractEdge(e, &graph, &remove_edges);
+
+        for (const tensorflow::Edge* r : remove_edges) {
+          contract_edges.erase(r);
+          graph.RemoveEdge(r);
+        }
+      }
+    }
+  }
+
+  // Collect the segments/subgraphs. Each subgraph is represented by a
+  // set of the names of the nodes in that subgraph.
+  std::unordered_map<std::string, std::set<std::string>> sg_map;
+  for (auto& u : node_segments) {
+    if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
+      sg_map[u.ParentValue()->name()].insert(u.Value()->name());
+    }
+  }
+
+  // Cleanup the graph to remove disconnected nodes before outputting
+  if (VLOG_IS_ON(2)) {
+    for (tensorflow::Node* node : graph.nodes()) {
+      if ((node->in_edges().size() == 0) && (node->out_edges().size() == 0)) {
+        graph.RemoveNode(node);
+      }
+    }
+    // tensorflow::DumpGraph("Post-Segment", &graph);
+  }
+
+  // Convert the segments into the expected return format
+  for (const auto& itr : sg_map) {
+    const auto& segment_node_names = itr.second;
+    if (VLOG_IS_ON(1)) {
+      std::string s;
+      for (const auto& name : segment_node_names) {
+        s += " " + name;
+      }
+      VLOG(1) << "Segment " << segments->size() << ":" << s;
+    }
+
+    // Don't use small segments.
+    if (static_cast<int>(segment_node_names.size()) <
+        options.minimum_segment_size) {
+      VLOG(1) << "Segment " << segments->size() << " has only "
+              << segment_node_names.size() << " nodes, dropping";
+      continue;
+    }
+
+    segments->emplace_back(segment_node_names);
+  }
+
+  return tensorflow::Status::OK();
+}
+
+}  // namespace segment
+}  // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
new file mode 100644
index 0000000..b5aee5b
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
+
+#include <set>
+#include <vector>
+#include <string>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorrt {
+namespace segment {
+
+using SegmentNodesVector = std::vector<std::set<std::string>>;
+
+struct SegmentOptions {
+  // Segment must contain at least this many nodes.
+  int minimum_segment_size = 2;
+};
+
+// Get the subgraphs of a graph that can be handled by TensorRT.
+//
+// @param gdef The GraphDef describing the network
+// @param candidate_fn A function that returns true for a NodeDef if
+// that node can be handled by TensorRT.
+// @param segments Returns the TensorRT segments/subgraphs. Each entry
+// in the vector describes a subgraph by giving a set of the names of
+// all the NodeDefs in that subgraph.
+// @return the status.
+tensorflow::Status SegmentGraph(
+    const tensorflow::GraphDef& gdef,
+    const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
+    const SegmentOptions& options, SegmentNodesVector* segments);
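+
+// Illustrative call (the candidate function name is a placeholder):
+//   SegmentNodesVector segments;
+//   TF_CHECK_OK(SegmentGraph(gdef, IsTensorRTCandidate, SegmentOptions(),
+//                            &segments));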
+
+}  // namespace segment
+}  // namespace tensorrt
+
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
new file mode 100644
index 0000000..dcd0c71
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -0,0 +1,363 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+
+//------------------------------------------------------------------------------
+using namespace tensorflow;
+
+namespace tensorrt {
+namespace segment {
+namespace test {
+
+class SegmentTest : public ::testing::Test {
+ public:
+  bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
+
+  TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
+  TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+                    TF_Status* s, const char* name);
+
+  std::function<bool(const NodeDef&)> MakeCandidateFn(
+      const std::set<std::string>& node_names);
+
+ protected:
+  void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
+                         TF_Operation** op);
+  void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+                 TF_Status* s, const char* name, TF_Operation** op, bool check);
+
+  SegmentOptions default_options_;
+};
+
+bool SegmentTest::GetGraphDef(TF_Graph* graph,
+                              tensorflow::GraphDef* graph_def) {
+  TF_Status* s = TF_NewStatus();
+  TF_Buffer* buffer = TF_NewBuffer();
+  TF_GraphToGraphDef(graph, buffer, s);
+  bool ret = TF_GetCode(s) == TF_OK;
+  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
+  TF_DeleteBuffer(buffer);
+  TF_DeleteStatus(s);
+  return ret;
+}
+
+std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn(
+    const std::set<std::string>& node_names) {
+  return [node_names](const NodeDef& node) -> bool {
+    return node_names.find(node.name()) != node_names.end();
+  };
+}
+
+void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
+                                    const char* name, TF_Operation** op) {
+  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
+  TF_SetAttrType(desc, "dtype", TF_INT32);
+  *op = TF_FinishOperation(desc, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
+                                       const char* name) {
+  TF_Operation* op;
+  PlaceholderHelper(graph, s, name, &op);
+  return op;
+}
+
+void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+                            TF_Status* s, const char* name, TF_Operation** op,
+                            bool check) {
+  TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+  TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
+  TF_AddInputList(desc, add_inputs, 2);
+  *op = TF_FinishOperation(desc, s);
+  if (check) {
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    ASSERT_NE(*op, nullptr);
+  }
+}
+
+TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
+                               TF_Graph* graph, TF_Status* s,
+                               const char* name) {
+  TF_Operation* op;
+  AddHelper(l, r, graph, s, name, &op, true);
+  return op;
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, Empty) {
+  TF_Graph* graph = TF_NewGraph();
+
+  GraphDef graph_def;
+  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+  SegmentNodesVector segments;
+  ASSERT_EQ(
+      SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments),
+      tensorflow::Status::OK());
+
+  // Expect no segments/subgraphs.
+  EXPECT_TRUE(segments.empty());
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, Simple) {
+  TF_Status* s = TF_NewStatus();
+  TF_Graph* graph = TF_NewGraph();
+
+  //           feed
+  //         //    ||
+  //       add0    add1
+  //        | |    /
+  //        |  add2
+  //        |  /  ||
+  //       add3    add4
+  //           |  /
+  //          <sink>
+  //
+  TF_Operation* feed = Placeholder(graph, s, "feed");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+  TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+
+  GraphDef graph_def;
+  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+  SegmentNodesVector segments;
+  ASSERT_EQ(
+      SegmentGraph(graph_def,
+                   MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
+                   default_options_, &segments),
+      tensorflow::Status::OK());
+
+  // Expect all Add operations to be collapsed into a single segment
+  ASSERT_EQ(segments.size(), 1);
+  std::vector<std::string> expected{"add0", "add1", "add2", "add3", "add4"};
+  for (const auto& ex : expected) {
+    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+        << "Missing expected node " << ex;
+  }
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, AvoidCycle) {
+  TF_Status* s = TF_NewStatus();
+  TF_Graph* graph = TF_NewGraph();
+
+  // add2 is not a TRT candidate so add0/add3 cannot be formed as a
+  // subgraph
+  //
+  //           feed
+  //         //    ||
+  //       add0    add1
+  //        | |    /
+  //        |  add2
+  //        |  /  ||
+  //       add3    add4
+  //           |  /
+  //          <sink>
+  //
+  TF_Operation* feed = Placeholder(graph, s, "feed");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+  TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+
+  GraphDef graph_def;
+  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+  SegmentNodesVector segments;
+  ASSERT_EQ(
+      SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
+                   default_options_, &segments),
+      tensorflow::Status::OK());
+
+  // Expect no subgraphs
+  EXPECT_EQ(segments.size(), 0);
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, Multiple) {
+  TF_Status* s = TF_NewStatus();
+  TF_Graph* graph = TF_NewGraph();
+
+  // add5 is not a TRT candidate so two subgraphs should be formed
+  //
+  //                feed
+  //         //      ||     ||
+  //       add0    add1      add7
+  //        | |    /        /   ||
+  //        |  add2-----add5    add8
+  //        |  /  |    |  |    |
+  //       add3   add4     add6
+  //           |     |     /
+  //               <sink>
+  //
+  TF_Operation* feed = Placeholder(graph, s, "feed");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+  TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+  TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));
+
+  GraphDef graph_def;
+  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+  SegmentNodesVector segments;
+  ASSERT_EQ(SegmentGraph(graph_def,
+                         MakeCandidateFn({"add0", "add1", "add2", "add3",
+                                          "add4", "add6", "add7", "add8"}),
+                         default_options_, &segments),
+            tensorflow::Status::OK());
+
+  // Expect two subgraphs
+  EXPECT_EQ(segments.size(), 2);
+
+  std::vector<std::string> expected0{"add0", "add1", "add2", "add3"};
+  for (const auto& ex : expected0) {
+    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+        << "Missing expected node " << ex;
+  }
+
+  std::vector<std::string> expected1{"add6", "add8"};
+  for (const auto& ex : expected1) {
+    EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+        << "Missing expected node " << ex;
+  }
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, BigIfElse) {
+  TF_Status* s = TF_NewStatus();
+  TF_Graph* graph = TF_NewGraph();
+
+  // add2 is not a TRT candidate
+  //
+  //           feed
+  //            ||
+  //           add0
+  //         //    ||
+  //       add1    add4
+  //        ||      ||
+  //       add2    add5
+  //        ||      ||
+  //       add3    add6
+  //         ||    //
+  //           add7
+  //            ||
+  //          <sink>
+  //
+  TF_Operation* feed = Placeholder(graph, s, "feed");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));
+
+  GraphDef graph_def;
+  ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+  SegmentNodesVector segments;
+  ASSERT_EQ(SegmentGraph(graph_def,
+                         MakeCandidateFn({"add0", "add1", "add3", "add4",
+                                          "add5", "add6", "add7"}),
+                         default_options_, &segments),
+            tensorflow::Status::OK());
+
+  // Expect 2 subgraphs
+  EXPECT_EQ(segments.size(), 2);
+
+  std::vector<std::string> expected0{"add3", "add4", "add5", "add6", "add7"};
+  for (const auto& ex : expected0) {
+    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+        << "Missing expected node " << ex;
+  }
+
+  std::vector<std::string> expected1{"add0", "add1"};
+  for (const auto& ex : expected1) {
+    EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+        << "Missing expected node " << ex;
+  }
+}
+
+}  // namespace test
+}  // namespace segment
+}  // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/contrib/tensorrt/segment/union_find.h
new file mode 100644
index 0000000..8ae877c
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/union_find.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
+
+namespace tensorrt {
+namespace segment {
+
+// Union-Find data structure.
+// Each cluster has an associated value; when merging clusters we can control
+// which value becomes the representative of the merged clusters. Values must be
+// copyable.
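+//
+// Illustrative use (values are placeholders):
+//   UnionFind<int> a(1), b(2);
+//   a.Merge(&b);   // a's value now represents the merged cluster
+//   // a.ParentValue() == 1, a.Size() == 2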
+template <typename T>
+class UnionFind {
+ public:
+  UnionFind() : size_(1), parent_(nullptr) {}
+  explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {}
+
+  // Returns the number of elements in a cluster.
+  int Size() { return FindRoot()->size_; }
+
+  // Merges this cluster with 'other'. This cluster's value becomes
+  // the value of the merged cluster; the value of 'other' is ignored.
+  void Merge(UnionFind* other);
+
+  // Each cluster has an associated value. Retrieves the value associated
+  // with this cluster.
+  T& ParentValue() { return FindRoot()->value_; }
+
+  // Get the original value of this node.
+  T& Value() { return value_; }
+
+ private:
+  // Finds the root element of the cluster. Performs path compression.
+  UnionFind* FindRoot();
+
+  int size_;
+  UnionFind* parent_;
+  T value_;
+};
+
+template <typename T>
+void UnionFind<T>::Merge(UnionFind* other) {
+  UnionFind<T>* a = FindRoot();
+  UnionFind<T>* b = other->FindRoot();
+  if (a == b) return;
+
+  b->parent_ = a;
+  a->size_ += b->size_;
+}
+
+template <typename T>
+UnionFind<T>* UnionFind<T>::FindRoot() {
+  if (!parent_) return this;
+  // Path compression: update intermediate nodes to point to the root of the
+  // equivalence class.
+  parent_ = parent_->FindRoot();
+  return parent_;
+}
+
+}  // namespace segment
+}  // namespace tensorrt
+
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
new file mode 100644
index 0000000..72022b9
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
+#include <string>
+#include <vector>
+#include "NvInfer.h"
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+
+namespace tensorflow {
+namespace shape_inference {
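+// Shape function for TRTEngineOp: deserializes the TensorRT engine from the
+// "serialized_engine" attr and builds each output shape from the batch size of
+// the first input followed by the corresponding engine binding's dimensions.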
+tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) {
+  tensorflow::tensorrt::Logger gLogger;
+  string serialized_engine;
+  c->GetAttr("serialized_engine", &serialized_engine);
+  nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(gLogger);
+  nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
+      serialized_engine.c_str(), serialized_engine.size(), nullptr);
+
+  // Debug print of the engine bindings.
+  std::stringstream oss;
+  for (int i = 0; i < trt_engine->getNbBindings(); i++) {
+    LOG(INFO) << "index: " << i
+              << ", binding name: " << trt_engine->getBindingName(i);
+
+    bool input_flag = trt_engine->bindingIsInput(i);
+    oss << "input?: " << (input_flag ? "Y" : "N");
+
+    oss << "Dimension: ";
+    auto dims = trt_engine->getBindingDimensions(i);
+    oss << " nbDims: " << dims.nbDims << " -> ";
+    for (int j = 0; j < dims.nbDims; j++) oss << dims.d[j] << ", ";
+    LOG(INFO) << oss.str();
+    oss.str("");
+    switch (trt_engine->getBindingDataType(i)) {
+      case nvinfer1::DataType::kFLOAT:
+        LOG(INFO) << "data type: float" << std::endl;
+        break;
+      case nvinfer1::DataType::kHALF:
+        LOG(INFO) << "data type: half" << std::endl;
+        break;
+      case nvinfer1::DataType::kINT8:
+        LOG(INFO) << "data type: int8" << std::endl;
+        break;
+    }
+  }
+
+  int nbBatch = -1;
+  // debug print out input arrays
+  std::vector<::tensorflow::DataType> input_type;
+  c->GetAttr("InT", &input_type);
+  oss.str("");
+  for (size_t i = 0; i < c->num_inputs(); i++) {
+    // check if input shape is legit
+    auto input_shape = c->input(i);
+    int index = i;
+    oss << "input:" << i << " type: " << input_type[index] << " shape: ";
+    for (int j = 0; j < c->Rank(input_shape); j++) {
+      auto dimHandler = c->Dim(input_shape, j);
+      if (c->ValueKnown(dimHandler))
+        oss << c->Value(dimHandler) << ", ";
+      else
+        oss << "?, ";
+      if (j == 0) {
+        if (i == 0)
+          nbBatch = c->Value(dimHandler);
+        else if (nbBatch != c->Value(dimHandler))
+          LOG(WARNING) << "Batch dimension does not match across inputs: "
+                       << nbBatch << " vs " << c->Value(dimHandler);
+        // assert(nbBatch == c->Value(dimHandler));
+      }
+    }
+    LOG(INFO) << oss.str();
+  }
+
+  // arrange input here
+  std::vector<string> input_nodes;
+  c->GetAttr("input_nodes", &input_nodes);
+  for (size_t i = 0; i < input_nodes.size(); i++) {
+    int index = i;
+    LOG(INFO) << "input:" << i << " name: " << input_nodes[index];
+  }
+
+  // arrange output here
+  std::vector<string> output_nodes;
+  c->GetAttr("output_nodes", &output_nodes);
+  oss.str("");
+  for (size_t i = 0; i < output_nodes.size(); i++) {
+    int index = i;
+    int binding_index =
+        trt_engine->getBindingIndex(output_nodes[index].c_str());
+    oss << "string name " << output_nodes[index];
+    ShapeHandle output_shape;
+    std::vector<DimensionHandle> vecDim;
+    vecDim.emplace_back(c->MakeDim(nbBatch));
+    if (binding_index != -1) {
+      oss << "got binding " << binding_index;
+      auto dims = trt_engine->getBindingDimensions(binding_index);
+      for (int j = 0; j < dims.nbDims; j++)
+        vecDim.emplace_back(c->MakeDim(dims.d[j]));
+    } else {
+      oss << "no binding ";
+    }
+    output_shape = c->MakeShape(vecDim);
+    c->set_output(i, output_shape);
+    LOG(INFO) << oss.str();
+  }
+
+  trt_engine->destroy();
+  infer->destroy();
+  return Status::OK();
+}
+}  // namespace shape_inference
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
new file mode 100644
index 0000000..90a226d
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
@@ -0,0 +1,28 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
+
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace shape_inference {
+Status TRTEngineOpShapeInference(InferenceContext* c);
+}  // namespace shape_inference
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
new file mode 100644
index 0000000..5f8e73a
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -0,0 +1,84 @@
+/*
+
+  wrap trt_conversion
+
+ */
+%{
+#define SWIG_FILE_WITH_INIT
+%}
+%include "std_string.i"
+%include "std_pair.i"
+%include "tensorflow/python/lib/core/strings.i"
+%include "tensorflow/python/platform/base.i"
+%template(StringPair) std::pair<string,string>;
+%template() std::pair<swig::SwigPtr_PyObject, swig::SwigPtr_PyObject>;
+
+%{
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/stat_summarizer.h"
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+%}
+
+%ignoreall
+%unignore tensorflow;
+%unignore trt_convert;
+
+%{
+  std::pair<string,string> trt_convert(string graph_def_string,//const tensorflow::GraphDef&
+		   std::vector<string> output_names,
+		   size_t max_batch_size,
+		   size_t max_workspace_size
+		   // Unfortunately we can't use TF_Status here since it lives
+		   // in c/c_api and pulls in many other libraries that also
+		   // declare ops. Those ops are statically linked into this
+		   // library as well, and the resulting double registration
+		   // aborts module loading. Until TensorFlow properly exposes
+		   // these headers we work around it by returning an encoded
+		   // status string and converting it to an exception on the
+		   // Python side.
+		   //,TF_Status* out_status) {
+		   ) {
+    string out_status;
+
+    tensorflow::GraphDef graph_def;
+    if (!graph_def.ParseFromString(graph_def_string)) {
+      out_status="InvalidArgument;Couldn't interpret input as a GraphDef";
+      return std::pair<string,string>{out_status,""};
+    }
+
+    if (!output_names.size()) {
+      out_status="InvalidArgument;Size of the output_names vector is 0";
+      return std::pair<string,string>{out_status,""};
+      //return "";
+    }
+    tensorflow::GraphDef outGraph;
+    tensorflow::Status conversion_status =
+      tensorrt::convert::ConvertGraphDefToTensorRT(graph_def,
+						   output_names,
+						   max_batch_size,
+						   max_workspace_size,
+						   &outGraph);
+    if (!conversion_status.ok()) {
+      auto retCode = static_cast<int>(conversion_status.code());
+      char buff[2000];
+      snprintf(buff, sizeof(buff), "%d;%s", retCode,
+               conversion_status.error_message().c_str());
+      out_status = buff;
+      return std::pair<string,string>{out_status,""};
+    }
+    string result;
+    if (!outGraph.SerializeToString(&result)) {
+      out_status="InvalidArgument;Couldn't serialize output as a GraphDef";
+      return std::pair<string,string>{out_status,""};
+    }
+    out_status="OK;All good!";
+    return std::pair<string,string>{out_status,result};
+  }
+%}
+
+std::pair<string,string> trt_convert(string graph_def_string,
+				     std::vector<string> output_names,
+				     size_t max_batch_size,
+				     size_t max_workspace_size);
+
+%unignoreall
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 383c973..838b121 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -279,7 +279,7 @@
     linkopts=[],
     framework_so=tf_binary_additional_srcs(),
     **kwargs):
-  native.cc_binary(
+    native.cc_binary(
       name=name,
       srcs=srcs + framework_so,
       deps=deps,
@@ -1281,6 +1281,45 @@
 def tf_extension_copts():
   return []  # No extension c opts
 
+# In libraries generated by tf_py_wrap_cc, module init functions are not
+# exported unless they contain one of the keywords in the version script,
+# which prevents loading custom python modules.
+# This rule appends init_<module_name> to the list of exported symbols in
+# the version script (or exported symbol list on macOS).
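+#
+# For example, with a hypothetical module_name "_pywrap_foo" the version
+# script's "global:" section becomes:
+#   global:
+#      init__pywrap_foo;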
+def _append_init_to_versionscript_impl(ctx):
+    modName=ctx.attr.module_name
+    isVS=ctx.attr.is_version_script
+    if isVS:
+        ctx.actions.expand_template(
+            template=ctx.file.template_file,
+            output=ctx.outputs.versionscript,
+            substitutions={
+                "global:":"global:\n   init_%s;"%modName,
+            },
+            is_executable=False,
+        )
+    else:
+        ctx.actions.expand_template(
+            template=ctx.file.template_file,
+            output=ctx.outputs.versionscript,
+            substitutions={
+                "*tensorflow*":"*tensorflow*\ninit_%s"%modName,
+            },
+            is_executable=False,
+        )
+
+
+_append_init_to_versionscript= rule(
+    implementation=_append_init_to_versionscript_impl,
+    attrs={
+        "module_name":attr.string(mandatory=True),
+        "template_file":attr.label(allow_files=True,single_file=True,mandatory=True),
+        "is_version_script":attr.bool(default=True,doc='whether target is a ld version script or exported symbol list',mandatory=False),
+    },
+    outputs={"versionscript":"%{name}.lds"},
+)
+
 def tf_py_wrap_cc(name,
                              srcs,
                              swig_includes=[],
@@ -1302,26 +1341,39 @@
       toolchain_deps=["//tools/defaults:crosstool"],
       module_name=module_name,
       py_module_name=name)
+  vscriptname=name+"_versionscript"
+  _append_init_to_versionscript(
+      name=vscriptname,
+      module_name=module_name,
+      is_version_script=select({
+          "@local_config_cuda//cuda:darwin":False,
+          "//conditions:default":True,
+          }),
+      template_file=select({
+          "@local_config_cuda//cuda:darwin":clean_dep("//tensorflow:tf_exported_symbols.lds"),
+          "//conditions:default":clean_dep("//tensorflow:tf_version_script.lds")
+      })
+  )
   extra_linkopts = select({
       "@local_config_cuda//cuda:darwin": [
           "-Wl,-exported_symbols_list",
-          clean_dep("//tensorflow:tf_exported_symbols.lds")
+          "%s.lds"%vscriptname,
       ],
       clean_dep("//tensorflow:windows"): [],
       clean_dep("//tensorflow:windows_msvc"): [],
       "//conditions:default": [
           "-Wl,--version-script",
-          clean_dep("//tensorflow:tf_version_script.lds")
+          "%s.lds"%vscriptname,
       ]
   })
   extra_deps += select({
       "@local_config_cuda//cuda:darwin": [
-          clean_dep("//tensorflow:tf_exported_symbols.lds")
+          "%s.lds"%vscriptname,
       ],
       clean_dep("//tensorflow:windows"): [],
       clean_dep("//tensorflow:windows_msvc"): [],
       "//conditions:default": [
-          clean_dep("//tensorflow:tf_version_script.lds")
+          "%s.lds"%vscriptname,
       ]
   })
 
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index ff5dd6a..f47df0e 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -11,6 +11,7 @@
 )
 load("//third_party/mkl:build_defs.bzl", "if_mkl")
 load("//tensorflow:tensorflow.bzl", "if_cuda")
+load("@local_config_tensorrt//:build_defs.bzl", "if_trt")
 load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
 
 # This returns a list of headers of all public header libraries (e.g.,
@@ -201,7 +202,8 @@
             "//tensorflow/python:test_ops",
             "//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
         ],
-    }) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
+    }) + if_mkl(["//third_party/mkl:intel_binary_blob"])
+    + if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
 )
 
 # A genrule for generating a marker file for the pip package on Windows
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 0ba3cca..8850610 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,6 +1,7 @@
 # TensorFlow external dependencies that can be loaded in WORKSPACE files.
 
 load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
+load("//third_party/tensorrt:build_defs.bzl", "trt_repository")
 load("//third_party/mkl:build_defs.bzl", "mkl_repository")
 load("//third_party/git:git_configure.bzl", "git_configure")
 load("//third_party/py:python_configure.bzl", "python_configure")
@@ -66,6 +67,7 @@
   # version we require here.
   check_bazel_version_at_least("0.5.4")
   cuda_configure(name="local_config_cuda")
+  trt_repository(name="local_config_tensorrt")
   git_configure(name="local_config_git")
   sycl_configure(name="local_config_sycl")
   python_configure(name="local_config_python")
diff --git a/third_party/tensorrt/BUILD b/third_party/tensorrt/BUILD
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/third_party/tensorrt/BUILD
diff --git a/third_party/tensorrt/BUILD.tpl b/third_party/tensorrt/BUILD.tpl
new file mode 100644
index 0000000..8962751
--- /dev/null
+++ b/third_party/tensorrt/BUILD.tpl
@@ -0,0 +1,42 @@
+# -*- python -*-
+# Description:
+#   provide tensorrt information
+
+# TODO(Sami): these need to be defined.
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")
+
+config_setting(
+    name = "trt_enabled",
+    define_values = {
+        "using_tensorrt":"true"
+    },
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "tensorrt",
+    srcs = [%{tensorrt_lib}],
+    hdrs = [
+        "include/NvInfer.h",
+        "include/NvUtils.h",
+    ],
+    copts = cuda_default_copts(),
+    includes = ["include/"],
+    linkstatic = 1,
+    visibility = ["//visibility:public"],
+    deps = [
+        "@local_config_cuda//cuda:cuda",
+        "@local_config_cuda//cuda:cudnn",
+    ],
+)
+
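+# Genrule generated by trt_repository (empty when TensorRT is disabled) that
+# creates the lib/ and include/ symlinks referenced above.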
+%{tensorrt_genrules}
diff --git a/third_party/tensorrt/LICENSE b/third_party/tensorrt/LICENSE
new file mode 100644
index 0000000..d3da228
--- /dev/null
+++ b/third_party/tensorrt/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2015 The TensorFlow Authors.  All rights reserved.
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2015, The TensorFlow Authors.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/third_party/tensorrt/build_defs.bzl b/third_party/tensorrt/build_defs.bzl
new file mode 100644
index 0000000..392c5e0
--- /dev/null
+++ b/third_party/tensorrt/build_defs.bzl
@@ -0,0 +1,85 @@
+# -*- python -*-
+"""
+ add a repo_generator rule for tensorrt
+
+"""
+
+_TENSORRT_INSTALLATION_PATH="TENSORRT_INSTALL_PATH"
+_TF_TENSORRT_VERSION="TF_TENSORRT_VERSION"
+
+def _is_trt_enabled(repo_ctx):
+    if "TF_NEED_TENSORRT" in repo_ctx.os.environ:
+        enable_trt = repo_ctx.os.environ["TF_NEED_TENSORRT"].strip()
+        return enable_trt == "1"
+    return False
+
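+# Writes a stub local_config_tensorrt (no library, empty headers) so that targets
+# referencing @local_config_tensorrt still load when TensorRT support is disabled.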
+def _dummy_repo(repo_ctx):
+    repo_ctx.template("BUILD", Label("//third_party/tensorrt:BUILD.tpl"),
+                      {"%{tensorrt_lib}": "", "%{tensorrt_genrules}": ""},
+                      False)
+    repo_ctx.template("build_defs.bzl", Label("//third_party/tensorrt:build_defs.bzl.tpl"),
+                      {"%{trt_configured}": "False"}, False)
+    repo_ctx.file("include/NvUtils.h", "", False)
+    repo_ctx.file("include/NvInfer.h", "", False)
+
+def _trt_repo_impl(repo_ctx):
+    """
+    Implements local_config_tensorrt
+    """
+
+    if not _is_trt_enabled(repo_ctx):
+        _dummy_repo(repo_ctx)
+        return
+    trt_libdir = repo_ctx.os.environ[_TENSORRT_INSTALLATION_PATH]
+    trt_ver = repo_ctx.os.environ[_TF_TENSORRT_VERSION]
+    # Deb installations put the headers under /usr/include/x86_64-linux-gnu, while
+    # tar installations keep them next to the library. This special case can be
+    # removed once the two installation layouts are standardized.
+    if trt_libdir == '/usr/lib/x86_64-linux-gnu':
+        incPath = '/usr/include/x86_64-linux-gnu'
+        incname = incPath + '/NvInfer.h'
+    else:
+        incPath = str(repo_ctx.path("%s/../include" % trt_libdir).realpath)
+        incname = incPath + '/NvInfer.h'
+    if len(trt_ver) > 0:
+        origLib = "%s/libnvinfer.so.%s" % (trt_libdir, trt_ver)
+    else:
+        origLib = "%s/libnvinfer.so" % trt_libdir
+    objdump = repo_ctx.which("objdump")
+    if objdump == None:
+        if len(trt_ver) > 0:
+            targetlib = "lib/libnvinfer.so.%s" % (trt_ver[0])
+        else:
+            targetlib = "lib/libnvinfer.so"
+    else:
+        soname = repo_ctx.execute([objdump, "-p", origLib])
+        for l in soname.stdout.splitlines():
+            if "SONAME" in l:
+                lib = l.strip().split(" ")[-1]
+                targetlib = "lib/%s" % (lib)
+
+    repo_ctx.symlink(origLib, targetlib)
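+    # Build a genrule (named "trtlinks") that symlinks the TensorRT library and
+    # headers into the generated repository's output tree at build time.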
+    grule=('genrule(\n    name = "trtlinks",\n'+
+           '    outs = [\n    "%s",\n    "include/NvInfer.h",\n    "include/NvUtils.h",\n     ],\n'%targetlib +
+           '    cmd="""ln -sf %s $(@D)/%s '%(origLib,targetlib) +
+           '&&\n    ln -sf %s $(@D)/include/NvInfer.h '%(incname) +
+           '&&\n    ln -sf %s/NvUtils.h $(@D)/include/NvUtils.h""",\n)\n'%(incPath))
+    repo_ctx.template("BUILD",Label("//third_party/tensorrt:BUILD.tpl"),
+                      {"%{tensorrt_lib}":'"%s"'%targetlib,"%{tensorrt_genrules}":grule},
+                      False)
+    repo_ctx.template("build_defs.bzl",Label("//third_party/tensorrt:build_defs.bzl.tpl"),
+                      {"%{trt_configured}":"True"},False)
+
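+# Repository rule used from tensorflow/workspace.bzl as
+# trt_repository(name="local_config_tensorrt"); listing the environment variables
+# in `environ` makes Bazel re-run the rule when any of them change.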
+trt_repository = repository_rule(
+    implementation = _trt_repo_impl,
+    local = True,
+    environ = [
+        "TF_NEED_TENSORRT",
+        _TF_TENSORRT_VERSION,
+        _TENSORRT_INSTALLATION_PATH,
+    ],
+)
diff --git a/third_party/tensorrt/build_defs.bzl.tpl b/third_party/tensorrt/build_defs.bzl.tpl
new file mode 100644
index 0000000..18f354e
--- /dev/null
+++ b/third_party/tensorrt/build_defs.bzl.tpl
@@ -0,0 +1,18 @@
+# -*- python -*-
+"""
+template file for trt functions
+
+"""
+
+def is_trt_enabled():
+    return %{trt_configured}
+
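+# Selects if_true when the :trt_enabled config_setting matches
+# (--define using_tensorrt=true), if_false otherwise.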
+def if_trt(if_true, if_false=[]):
+    return select({
+        "@local_config_tensorrt//:trt_enabled": if_true,
+        "//conditions:default": if_false,
+    })