[XLA:Python] Add bindings for xla::PrecisionConfig.
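
Ops that take a precision config (DotGeneral, Conv, ConvWithGeneralPadding,
ConvGeneralDilated) now accept an optional precision_config keyword argument;
passing None (the default) keeps the previous behavior. A minimal usage
sketch (it assumes the existing xla_client.ComputationBuilder entry point,
which is not part of this change):

  import numpy as np
  from tensorflow.compiler.xla.python import xla_client

  # Assumed builder setup; any ComputationBuilder works here.
  c = xla_client.ComputationBuilder("dot_with_precision")
  lhs = c.Constant(np.ones((3, 4), dtype=np.float32))
  rhs = c.Constant(np.ones((4, 5), dtype=np.float32))
  # Contract lhs dim 1 with rhs dim 0; no batch dimensions.
  dimension_numbers = (([1], [0]), ([], []))
  # One Precision entry per operand, in operand order (lhs, then rhs).
  config = xla_client.PrecisionConfig()
  config.operand_precision.append(config.Precision.HIGHEST)
  config.operand_precision.append(config.Precision.DEFAULT)
  c.DotGeneral(lhs, rhs, dimension_numbers, precision_config=config)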

PiperOrigin-RevId: 254792785
diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h
index 2d0f7c6..bc0ee2b 100644
--- a/tensorflow/compiler/xla/python/types.h
+++ b/tensorflow/compiler/xla/python/types.h
@@ -424,6 +424,31 @@
     return true;
   }
 };
+
+template <>
+struct type_caster<xla::PrecisionConfig> {
+ public:
+  PYBIND11_TYPE_CASTER(xla::PrecisionConfig, _("xla::PrecisionConfig"));
+
+  // PyObject -> C++ conversion.
+  bool load(handle src, bool) {
+    // None becomes a default-constructed proto (default precision).
+    if (src.is_none()) {
+      return true;
+    }
+
+    sequence operand_precisions =
+        reinterpret_borrow<sequence>(getattr(src, "operand_precision"));
+
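+    // Copy each Python Precision value into the proto's repeated field.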
+    for (auto operand_precision : operand_precisions) {
+      value.add_operand_precision(
+          operand_precision.cast<xla::PrecisionConfig::Precision>());
+    }
+    return true;
+  }
+};
+
 }  // namespace detail
 }  // namespace pybind11
 
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index 40bf429..23eec49 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -594,9 +594,14 @@
       .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
       .value("ADJOINT", TriangularSolveOptions::ADJOINT);
 
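+  // Mirrors the xla.PrecisionConfig.Precision proto enum.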
+  py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision")
+      .value("DEFAULT", PrecisionConfig::DEFAULT)
+      .value("HIGH", PrecisionConfig::HIGH)
+      .value("HIGHEST", PrecisionConfig::HIGHEST);
+
   // TODO(phawkins): improve bindings for these types.
   py::class_<ChannelHandle>(m, "ChannelHandle");
-  py::class_<PrecisionConfig>(m, "PrecisionConfig");
 
   tensorflow::AddXrtSubmodule(&m);
 }
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index a86400c..a2e7fc2 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1256,7 +1256,7 @@
     """
     return ops.BuildConstantSubGraph(operand)
 
-  def DotGeneral(self, lhs, rhs, dimension_numbers):
+  def DotGeneral(self, lhs, rhs, dimension_numbers, precision_config=None):
     """Enqueues a general dot operation onto the computation.
 
     Args:
@@ -1270,10 +1270,17 @@
     """
     if isinstance(dimension_numbers, tuple):
       dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
-    return ops.DotGeneral(lhs, rhs, dimension_numbers)
+    return ops.DotGeneral(
+        lhs, rhs, dimension_numbers, precision_config=precision_config)
 
-  def Conv(self, lhs, rhs, window_strides, padding,
-           feature_group_count=1, batch_group_count=1):
+  def Conv(self,
+           lhs,
+           rhs,
+           window_strides,
+           padding,
+           feature_group_count=1,
+           batch_group_count=1,
+           precision_config=None):
     """Enqueues a Conv operation onto the computation.
 
     Args:
@@ -1296,7 +1303,8 @@
         pads, [], [],
         dimension_numbers=None,
         feature_group_count=feature_group_count,
-        batch_group_count=batch_group_count)
+        batch_group_count=batch_group_count,
+        precision_config=precision_config)
 
   def ConvWithGeneralPadding(self,
                              lhs,
@@ -1306,7 +1314,8 @@
                              lhs_dilation,
                              rhs_dilation,
                              feature_group_count=1,
-                             batch_group_count=1):
+                             batch_group_count=1,
+                             precision_config=None):
     """Enqueues a ConvWithGeneralPadding operation onto the computation.
 
     Args:
@@ -1331,7 +1340,8 @@
         list(rhs_dilation),
         dimension_numbers=None,
         feature_group_count=feature_group_count,
-        batch_group_count=batch_group_count)
+        batch_group_count=batch_group_count,
+        precision_config=precision_config)
 
   def _GetConvDimensionNumbers(self, num_spatial_dims):
     """Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1357,7 +1367,8 @@
                          rhs_dilation,
                          dimension_numbers=None,
                          feature_group_count=1,
-                         batch_group_count=1):
+                         batch_group_count=1,
+                         precision_config=None):
     """Enqueues a ConvGeneralDilated operation onto the computation.
 
     Args:
@@ -1411,9 +1422,17 @@
       dimension_numbers.output_spatial_dimensions.extend(
           sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
                  key=lambda i: rhs_spec.index(out_spec[i])))
-    return ops.ConvGeneralDilated(lhs, rhs, window_strides, padding,
-                                  lhs_dilation, rhs_dilation, dimension_numbers,
-                                  feature_group_count, batch_group_count)
+    return ops.ConvGeneralDilated(
+        lhs,
+        rhs,
+        window_strides,
+        padding,
+        lhs_dilation,
+        rhs_dilation,
+        dimension_numbers,
+        feature_group_count,
+        batch_group_count,
+        precision_config=precision_config)
 
   def Sort(self, operand, dimension=-1):
     """Enqueues a sort operation onto the computation."""
@@ -1657,6 +1676,16 @@
     self.output_spatial_dimensions = []
 
 
+class PrecisionConfig(object):
+  """Python representation of a xla.PrecisionConfig protobuf."""
+  __slots__ = ('operand_precision',)
+
+  Precision = _xla.PrecisionConfig_Precision
+
+  def __init__(self):
+    self.operand_precision = []
+
+
 class GatherDimensionNumbers(object):
   """Python representation of a xla.GatherDimensionNumbers protobuf."""
   __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 4a90cc3..e6cfd46 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -716,6 +716,23 @@
     c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
     self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
 
+  def testDotGeneralWithPrecisionConfig(self):
+    c = self._NewComputation()
+    rng = np.random.RandomState(0)
+    lhs = NumpyArrayF32(rng.randn(10, 3, 4))
+    rhs = NumpyArrayF32(rng.randn(10, 4, 5))
+    dimension_numbers = (([2], [1]), ([0], [0]))
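+    # Per-operand precision: entries follow operand order (lhs, then rhs).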
+    config = xla_client.PrecisionConfig()
+    config.operand_precision.append(config.Precision.HIGH)
+    config.operand_precision.append(config.Precision.HIGHEST)
+    c.DotGeneral(
+        c.Constant(lhs),
+        c.Constant(rhs),
+        dimension_numbers,
+        precision_config=config)
+    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+
   def testConvF32Same(self):
     c = self._NewComputation()
     a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
@@ -784,6 +801,37 @@
     ]]])
     self._ExecuteAndCompareClose(c, expected=result)
 
+  def testConvGeneralDilatedF32WithPrecisionConfig(self):
+    c = self._NewComputation()
+    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+    lhs = a(1, 1, 2, 3)
+    rhs = a(1, 1, 1, 2) * 10
+    strides = [1, 1]
+    pads = [(1, 0), (0, 1)]
+    lhs_dilation = (2, 1)
+    rhs_dilation = (1, 1)
+    dimension_numbers = ("NCHW", "OIHW", "NCHW")
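+    # Per-operand precision for the convolution's lhs and rhs operands.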
+    config = xla_client.PrecisionConfig()
+    config.operand_precision.append(config.Precision.HIGHEST)
+    config.operand_precision.append(config.Precision.DEFAULT)
+    c.ConvGeneralDilated(
+        c.Constant(lhs),
+        c.Constant(rhs),
+        strides,
+        pads,
+        lhs_dilation,
+        rhs_dilation,
+        dimension_numbers,
+        precision_config=config)
+    result = np.array([[[
+        [0., 0., 0.],
+        [10., 20., 0.],
+        [0., 0., 0.],
+        [40., 50., 0.],
+    ]]])
+    self._ExecuteAndCompareClose(c, expected=result)
+
   def testConvGeneralDilatedPermutedF32(self):
     c = self._NewComputation()
     a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")