[XLA:Python] Add bindings for xla::PrecisionConfig.
PiperOrigin-RevId: 254792785
diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h
index 2d0f7c6..bc0ee2b 100644
--- a/tensorflow/compiler/xla/python/types.h
+++ b/tensorflow/compiler/xla/python/types.h
@@ -424,6 +424,29 @@
return true;
}
};
+
+template <>
+struct type_caster<xla::PrecisionConfig> {
+ public:
+ PYBIND11_TYPE_CASTER(xla::PrecisionConfig, _("xla::PrecisionConfig"));
+
+ // PyObject -> C++ conversion.
+ bool load(handle handle, bool) {
+ if (handle.is_none()) {
+ return true;
+ }
+
+ sequence operand_precisions =
+ reinterpret_borrow<sequence>(getattr(handle, "operand_precision"));
+
+ for (auto operand_precision : operand_precisions) {
+ value.add_operand_precision(
+ operand_precision.cast<xla::PrecisionConfig::Precision>());
+ }
+ return true;
+ }
+};
+
} // namespace detail
} // namespace pybind11
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index 40bf429..23eec49 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -594,9 +594,13 @@
.value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
.value("ADJOINT", TriangularSolveOptions::ADJOINT);
+ py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision")
+ .value("DEFAULT", PrecisionConfig::DEFAULT)
+ .value("HIGH", PrecisionConfig::HIGH)
+ .value("HIGHEST", PrecisionConfig::HIGHEST);
+
// TODO(phawkins): improve bindings for these types.
py::class_<ChannelHandle>(m, "ChannelHandle");
- py::class_<PrecisionConfig>(m, "PrecisionConfig");
tensorflow::AddXrtSubmodule(&m);
}
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index a86400c..a2e7fc2 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1256,7 +1256,7 @@
"""
return ops.BuildConstantSubGraph(operand)
- def DotGeneral(self, lhs, rhs, dimension_numbers):
+ def DotGeneral(self, lhs, rhs, dimension_numbers, precision_config=None):
"""Enqueues a general dot operation onto the computation.
Args:
@@ -1270,10 +1270,17 @@
"""
if isinstance(dimension_numbers, tuple):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
- return ops.DotGeneral(lhs, rhs, dimension_numbers)
+ return ops.DotGeneral(
+ lhs, rhs, dimension_numbers, precision_config=precision_config)
- def Conv(self, lhs, rhs, window_strides, padding,
- feature_group_count=1, batch_group_count=1):
+ def Conv(self,
+ lhs,
+ rhs,
+ window_strides,
+ padding,
+ feature_group_count=1,
+ batch_group_count=1,
+ precision_config=None):
"""Enqueues a Conv operation onto the computation.
Args:
@@ -1296,7 +1303,8 @@
pads, [], [],
dimension_numbers=None,
feature_group_count=feature_group_count,
- batch_group_count=batch_group_count)
+ batch_group_count=batch_group_count,
+ precision_config=precision_config)
def ConvWithGeneralPadding(self,
lhs,
@@ -1306,7 +1314,8 @@
lhs_dilation,
rhs_dilation,
feature_group_count=1,
- batch_group_count=1):
+ batch_group_count=1,
+ precision_config=None):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
@@ -1331,7 +1340,8 @@
list(rhs_dilation),
dimension_numbers=None,
feature_group_count=feature_group_count,
- batch_group_count=batch_group_count)
+ batch_group_count=batch_group_count,
+ precision_config=precision_config)
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1357,7 +1367,8 @@
rhs_dilation,
dimension_numbers=None,
feature_group_count=1,
- batch_group_count=1):
+ batch_group_count=1,
+ precision_config=None):
"""Enqueues a ConvGeneralDilated operation onto the computation.
Args:
@@ -1411,9 +1422,17 @@
dimension_numbers.output_spatial_dimensions.extend(
sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
key=lambda i: rhs_spec.index(out_spec[i])))
- return ops.ConvGeneralDilated(lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation, dimension_numbers,
- feature_group_count, batch_group_count)
+ return ops.ConvGeneralDilated(
+ lhs,
+ rhs,
+ window_strides,
+ padding,
+ lhs_dilation,
+ rhs_dilation,
+ dimension_numbers,
+ feature_group_count,
+ batch_group_count,
+ precision_config=precision_config)
def Sort(self, operand, dimension=-1):
"""Enqueues a sort operation onto the computation."""
@@ -1657,6 +1676,16 @@
self.output_spatial_dimensions = []
+class PrecisionConfig(object):
+ """Python representation of a xla.PrecisionConfig protobuf."""
+ __slots__ = ('operand_precision',)
+
+ Precision = _xla.PrecisionConfig_Precision
+
+ def __init__(self):
+ self.operand_precision = []
+
+
class GatherDimensionNumbers(object):
"""Python representation of a xla.GatherDimensionNumbers protobuf."""
__slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 4a90cc3..e6cfd46 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -716,6 +716,22 @@
c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+ def testDotGeneralWithPrecisionConfig(self):
+ c = self._NewComputation()
+ rng = np.random.RandomState(0)
+ lhs = NumpyArrayF32(rng.randn(10, 3, 4))
+ rhs = NumpyArrayF32(rng.randn(10, 4, 5))
+ dimension_numbers = (([2], [1]), ([0], [0]))
+ config = xla_client.PrecisionConfig()
+ config.operand_precision.append(config.Precision.HIGH)
+ config.operand_precision.append(config.Precision.HIGHEST)
+ c.DotGeneral(
+ c.Constant(lhs),
+ c.Constant(rhs),
+ dimension_numbers,
+ precision_config=config)
+ self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+
def testConvF32Same(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
@@ -784,6 +800,36 @@
]]])
self._ExecuteAndCompareClose(c, expected=result)
+ def testConvGeneralDilatedF32WithPrecisionConfig(self):
+ c = self._NewComputation()
+ a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+ lhs = a(1, 1, 2, 3)
+ rhs = a(1, 1, 1, 2) * 10
+ strides = [1, 1]
+ pads = [(1, 0), (0, 1)]
+ lhs_dilation = (2, 1)
+ rhs_dilation = (1, 1)
+ dimension_numbers = ("NCHW", "OIHW", "NCHW")
+ config = xla_client.PrecisionConfig()
+ config.operand_precision.append(config.Precision.HIGHEST)
+ config.operand_precision.append(config.Precision.DEFAULT)
+ c.ConvGeneralDilated(
+ c.Constant(lhs),
+ c.Constant(rhs),
+ strides,
+ pads,
+ lhs_dilation,
+ rhs_dilation,
+ dimension_numbers,
+ precision_config=config)
+ result = np.array([[[
+ [0., 0., 0.],
+ [10., 20., 0.],
+ [0., 0., 0.],
+ [40., 50., 0.],
+ ]]])
+ self._ExecuteAndCompareClose(c, expected=result)
+
def testConvGeneralDilatedPermutedF32(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")