Fix Py3 byte and string issue after swig update. Clarify failure message on finding libnvinfer in configure.py
diff --git a/configure.py b/configure.py
index 9cf5902..3aa1a3e 100644
--- a/configure.py
+++ b/configure.py
@@ -1078,12 +1078,21 @@
break
# Reset and Retry
- print('Invalid path to TensorRT. None of the following files can be found:')
- print(trt_install_path)
- print(os.path.join(trt_install_path, 'lib'))
- print(os.path.join(trt_install_path, 'lib64'))
- if search_result:
- print(libnvinfer_path_from_ldconfig)
+ if len(possible_files):
+ print('TensorRT libraries found in one the following directories',
+ 'are not compatible with selected cuda and cudnn installations')
+ print(trt_install_path)
+ print(os.path.join(trt_install_path, 'lib'))
+ print(os.path.join(trt_install_path, 'lib64'))
+ if search_result:
+ print(libnvinfer_path_from_ldconfig)
+ else:
+ print('Invalid path to TensorRT. None of the following files can be found:')
+ print(trt_install_path)
+ print(os.path.join(trt_install_path, 'lib'))
+ print(os.path.join(trt_install_path, 'lib64'))
+ if search_result:
+ print(libnvinfer_path_from_ldconfig)
else:
raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 69bbf45..9454862 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -55,10 +55,18 @@
def py3bytes(inp):
return inp.encode("utf-8", errors="surrogateescape")
+ def py2string(inp):
+ return inp
+
+ def py3string(inp):
+ return inp.decode("utf-8")
+
if _six.PY2:
to_bytes = py2bytes
+ to_string = py2string
else:
to_bytes = py3bytes
+ to_string = py3string
out_names = []
for i in outputs:
@@ -76,8 +84,8 @@
# one is the transformed graphs protobuf string.
out = trt_convert(input_graph_def_str, out_names, max_batch_size,
max_workspace_size_bytes)
- status = out[0]
- output_graph_def_string = to_bytes(out[1])
+ status = to_string(out[0])
+ output_graph_def_string = out[1]
del input_graph_def_str # Save some memory
if len(status) < 2:
raise _impl.UnknownError(None, None, status)
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 927a3e4..adf3438 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -67,5 +67,7 @@
inpDims[0]) # Get optimized graph
o1 = runGraph(gdef, dummy_input)
o2 = runGraph(trt_graph, dummy_input)
+ o3 = runGraph(trt_graph, dummy_input)
assert (np.array_equal(o1, o2))
+ assert (np.array_equal(o2, o3))
print("Pass")