From 128b50a574298a8f325b9eaad1f2c768b250df66 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 5 Mar 2021 16:07:51 -0800
Subject: [PATCH] Adds get_sharding_tile_shape() function.

Adds many unit tests for xla_sharding.Sharding and helper functions, making use of the new get_sharding_tile_shape() function.
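
Example usage (an illustrative sketch only; the OpSharding contents below are
assumptions for demonstration and not part of this change):

    from tensorflow.compiler.xla import xla_data_pb2
    from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

    # A serialized OpSharding with a [2, 1, 3] tile assignment over 6 devices.
    tiled = xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.OTHER,
        tile_assignment_dimensions=[2, 1, 3],
        tile_assignment_devices=list(range(6)))
    xla_sharding.get_sharding_tile_shape(tiled.SerializeToString())  # [2, 1, 3]

    # A replicated sharding carries no tile assignment, so None is returned.
    replicated = xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.REPLICATED)
    xla_sharding.get_sharding_tile_shape(replicated.SerializeToString())  # None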

PiperOrigin-RevId: 361245107
Change-Id: I85ee889c9f2f59803cf59a5c2972fc3e75d3f037
---
 .../experimental/xla_sharding/xla_sharding.py |  22 ++++
 .../xla_sharding/xla_sharding_test.py         | 112 +++++++++++++++++-
 2 files changed, 133 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
index 83eec095889..0f1dcd89302 100644
--- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
@@ -367,6 +367,28 @@ def get_tensor_sharding(tensor):
     return None
 
 
+def get_sharding_tile_shape(sharding):
+  """Returns the tile assignment shape for a sharded Tensor.
+
+  Args:
+    sharding: a serialized OpSharding message describing the layout of a
+      sharded Tensor.
+
+  Returns:
+    The shape of the sharding's tile assignment: the number of shards along
+      each dimension of that assignment. Returns None if `sharding` is None
+      or carries no tile assignment.
+  """
+  if sharding is None:
+    return None
+  sharding_message = xla_data_pb2.OpSharding()
+  sharding_message.ParseFromString(sharding)
+  if sharding_message.tile_assignment_dimensions:
+    return sharding_message.tile_assignment_dimensions
+  else:
+    return None
+
+
 def auto_to_manual_spmd_partition(tensor, manual_sharding):
   """Switches from automatic SPMD partitioning to manual partitioning.
 
diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding_test.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding_test.py
index 6cb404bc1ed..826ae82628f 100644
--- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding_test.py
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding_test.py
@@ -17,12 +17,17 @@
 from absl.testing import absltest
 import numpy as np
 
+from google.protobuf.message import DecodeError
 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 
 
-class ShardingTest(absltest.TestCase):
+class ShardingTest(test_util.TensorFlowTestCase):
+  """Tests for member functions of the class xla_sharding.Sharding."""
 
   def test_sharding_is_default_constructable(self):
     sharding = xla_sharding.Sharding()
@@ -52,5 +57,110 @@ class ShardingTest(absltest.TestCase):
         xla_sharding.Sharding)
 
 
+class XlaShardingTest(test_util.TensorFlowTestCase):
+  """Tests for non-member functions in the module xla_sharding.py."""
+
+  def test_replicate_annotates_tensor_correctly(self):
+
+    @def_function.function
+    def replicate_helper(tensor):
+      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
+      replicated_tensor = xla_sharding.replicate(tensor)
+      self.assertIsInstance(replicated_tensor, ops.Tensor)
+      replicated_sharding = xla_sharding.get_tensor_sharding(replicated_tensor)
+      self.assertIsNotNone(replicated_sharding)
+      self.assertIsNone(
+          xla_sharding.get_sharding_tile_shape(replicated_sharding))
+      return replicated_tensor
+
+    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
+    result = replicate_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+    self.assertAllEqual(in_tensor, result)
+
+  def test_tile_annotates_tensor_correctly(self):
+
+    @def_function.function
+    def tile_helper(tensor):
+      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
+      tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6]))
+      self.assertIsInstance(tiled_tensor, ops.Tensor)
+      tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor)
+      tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding)
+      # np.array([2, 1, 6]) has shape [3]; tile() records that shape.
+      expected_shape = [3]
+      self.assertEqual(expected_shape, tile_shape)
+      return tiled_tensor
+
+    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
+    result = tile_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+    self.assertAllEqual(in_tensor, result)
+
+  def test_split_annotates_tensor_correctly(self):
+
+    @def_function.function
+    def split_helper(tensor):
+      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
+      split_tensor = xla_sharding.split(tensor, 2, 3)
+      self.assertIsInstance(split_tensor, ops.Tensor)
+      split_sharding = xla_sharding.get_tensor_sharding(split_tensor)
+      split_shape = xla_sharding.get_sharding_tile_shape(split_sharding)
+      expected_shape = [1, 1, 3]
+      self.assertEqual(expected_shape, split_shape)
+      return split_tensor
+
+    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
+    result = split_helper(
+        array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+    self.assertAllEqual(in_tensor, result)
+
+  def test_split_raises_error_with_incommensurate_dimensions(self):
+
+    @def_function.function
+    def split_helper(tensor):
+      split_tensor = xla_sharding.split(tensor, 0, 8)
+      return split_tensor
+
+    with self.assertRaises(ValueError):
+      _ = split_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+
+    # TODO(drm): Modify split() so that the call below also raises an error:
+    # 8 does not evenly divide 9, but split() currently only checks that the
+    # number of splits does not exceed the dimension size.
+    # with self.assertRaises(ValueError):
+    #   _ = split_helper(array_ops.ones([9, 5, 6], dtype=dtypes.float32))
+
+  def test_copy_sharding_succeeds_with_identically_shaped_tensors(self):
+
+    @def_function.function
+    def copy_helper(tensor):
+      tensor_src = array_ops.identity(tensor)
+      tensor_src = xla_sharding.split(tensor_src, 2, 3)
+      sharding_src = xla_sharding.get_tensor_sharding(tensor_src)
+      shape_src = xla_sharding.get_sharding_tile_shape(sharding_src)
+      self.assertEqual([1, 1, 3], shape_src)
+
+      tensor_dest = array_ops.identity(tensor)
+      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor_dest))
+
+      xla_sharding.copy_sharding(tensor_src, tensor_dest)
+      sharding_dest = xla_sharding.get_tensor_sharding(tensor_dest)
+      shape_dest = xla_sharding.get_sharding_tile_shape(sharding_dest)
+      self.assertEqual([1, 1, 3], shape_dest)
+      return tensor_dest
+
+    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
+    result = copy_helper(
+        array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+    self.assertAllEqual(in_tensor, result)
+
+  def test_get_sharding_tile_shape_returns_none_on_none_input(self):
+    self.assertIsNone(xla_sharding.get_sharding_tile_shape(None))
+
+  def test_get_sharding_tile_shape_raises_error_on_nonparsable_input(self):
+    bad_proto_data = b'\x0f'
+    with self.assertRaises(DecodeError):
+      xla_sharding.get_sharding_tile_shape(bad_proto_data)
+
+
 if __name__ == '__main__':
   absltest.main()