From 128b50a574298a8f325b9eaad1f2c768b250df66 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 5 Mar 2021 16:07:51 -0800
Subject: [PATCH] Adds get_sharding_tile_shape() function.

Adds many unit tests for xla_sharding.Sharding and helper functions,
making use of the new get_sharding_tile_shape() function.

PiperOrigin-RevId: 361245107
Change-Id: I85ee889c9f2f59803cf59a5c2972fc3e75d3f037
---
 .../experimental/xla_sharding/xla_sharding.py |  22 ++++
 .../xla_sharding/xla_sharding_test.py         | 112 +++++++++++++++++-
 2 files changed, 133 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
index 83eec095889..0f1dcd89302 100644
--- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
@@ -367,6 +367,28 @@ def get_tensor_sharding(tensor):
   return None
 
 
+def get_sharding_tile_shape(sharding):
+  """Returns the tile assignment shape for a sharded Tensor.
+
+  Args:
+    sharding: a serialized OpSharding message describing the layout of a
+      sharded Tensor.
+
+  Returns:
+    A list, for each dimension of the sharded Tensor, of the number of shards
+    into which it has been split. Returns None if the input indicates no tile
+    assignments.
+  """
+  if sharding is None:
+    return None
+  sharding_message = xla_data_pb2.OpSharding()
+  sharding_message.ParseFromString(sharding)
+  if sharding_message.tile_assignment_dimensions:
+    return sharding_message.tile_assignment_dimensions
+  else:
+    return None
+
+
 def auto_to_manual_spmd_partition(tensor, manual_sharding):
   """Switches from automatic SPMD partitioning to manual partitioning.
 
diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding_test.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding_test.py
index 6cb404bc1ed..826ae82628f 100644
--- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding_test.py
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding_test.py
@@ -17,12 +17,17 @@
 from absl.testing import absltest
 import numpy as np
 
+from google.protobuf.message import DecodeError
 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 
 
-class ShardingTest(absltest.TestCase):
+class ShardingTest(test_util.TensorFlowTestCase):
+  """Tests for member functions of the class xla_sharding.Sharding."""
 
   def test_sharding_is_default_constructable(self):
     sharding = xla_sharding.Sharding()
@@ -52,5 +57,110 @@ class ShardingTest(absltest.TestCase):
         xla_sharding.Sharding)
 
 
+class XlaShardingTest(test_util.TensorFlowTestCase):
+  """Tests for non-member functions in the module xla_sharding.py."""
+
+  def test_replicate_annotates_tensor_correctly(self):
+
+    @def_function.function
+    def replicate_helper(tensor):
+      replicated_tensor = xla_sharding.replicate(
+          array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
+      replicated_sharding = xla_sharding.get_tensor_sharding(replicated_tensor)
+      self.assertIsNotNone(replicated_sharding)
+      self.assertIsNone(
+          xla_sharding.get_sharding_tile_shape(replicated_sharding))
+      return replicated_tensor
+
+    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
+    result = replicate_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+    self.assertAllEqual(in_tensor, result)
+
+  def test_tile_annotates_tensor_correctly(self):
+
+    @def_function.function
+    def tile_helper(tensor):
+      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
+      tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6]))
+      self.assertIsInstance(tiled_tensor, ops.Tensor)
+      tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor)
+      tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding)
+      # The tile assignment np.array([2, 1, 6]) has shape [3] (3 elements).
+      expected_shape = [3]
+      self.assertEqual(expected_shape, tile_shape)
+      return tiled_tensor
+
+    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
+    result = tile_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+    self.assertAllEqual(in_tensor, result)
+
+  def test_split_annotates_tensor_correctly(self):
+
+    @def_function.function
+    def split_helper(tensor):
+      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
+      split_tensor = xla_sharding.split(tensor, 2, 3)
+      self.assertIsInstance(split_tensor, ops.Tensor)
+      split_sharding = xla_sharding.get_tensor_sharding(split_tensor)
+      split_shape = xla_sharding.get_sharding_tile_shape(split_sharding)
+      expected_shape = [1, 1, 3]
+      self.assertEqual(expected_shape, split_shape)
+      return split_tensor
+
+    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
+    result = split_helper(
+        array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+    self.assertAllEqual(in_tensor, result)
+
+  def test_split_raises_error_with_incommensurate_dimensions(self):
+
+    @def_function.function
+    def split_helper(tensor):
+      split_tensor = xla_sharding.split(tensor, 0, 8)
+      return split_tensor
+
+    with self.assertRaises(ValueError):
+      _ = split_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+
+    # TODO(drm): Modify split() so that this call raises an error since
+    # 8 does not divide 9 (currently only checks that 8 is smaller than 9,
+    # which it is, but this is not good for splitting).
+    # with self.assertRaises(ValueError):
+    #   _ = split_helper(array_ops.ones([9, 5, 6], dtype=dtypes.float32))
+
+  def test_copy_sharding_succeeds_with_identically_shaped_tensors(self):
+
+    @def_function.function
+    def copy_helper(tensor):
+      tensor_src = array_ops.identity(tensor)
+      tensor_src = xla_sharding.split(tensor, 2, 3)
+      sharding_src = xla_sharding.get_tensor_sharding(tensor_src)
+      shape_src = xla_sharding.get_sharding_tile_shape(sharding_src)
+      self.assertEqual([1, 1, 3], shape_src)
+
+      tensor_dest = array_ops.identity(tensor)
+      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor_dest))
+
+      xla_sharding.copy_sharding(tensor_src, tensor_dest)
+      sharding_dest = xla_sharding.get_tensor_sharding(tensor_dest)
+      shape_dest = xla_sharding.get_sharding_tile_shape(sharding_dest)
+      self.assertEqual([1, 1, 3], shape_dest)
+      return tensor_dest
+
+    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
+    result = copy_helper(
+        array_ops.ones([4, 5, 6], dtype=dtypes.float32))
+    self.assertAllEqual(in_tensor, result)
+
+  def test_get_sharding_tile_shape_returns_none_on_none_input(self):
+    self.assertIsNone(xla_sharding.get_sharding_tile_shape(None))
+
+  def test_get_sharding_tile_shape_raises_error_on_nonparsable_input(self):
+    bad_proto_data = b'\x0f'
+    with self.assertRaises(DecodeError):
+      xla_sharding.get_sharding_tile_shape(bad_proto_data)
+
+
 if __name__ == '__main__':
   absltest.main()
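
Usage sketch (not part of the patch), distilled from the tests above. It assumes the same def_function.function-traced context the tests use, since get_tensor_sharding() reads the sharding attribute from the tensor's producing op; the device assignment [2, 1, 6] and the helper name tile_and_check are illustrative only.

import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops


@def_function.function
def tile_and_check(tensor):
  # Annotate the tensor with a rank-1 tile assignment over devices [2, 1, 6].
  tiled = xla_sharding.tile(tensor, np.array([2, 1, 6]))
  # get_tensor_sharding() returns the serialized OpSharding proto attached to
  # the producing op; the new get_sharding_tile_shape() decodes its
  # tile_assignment_dimensions, here [3] (the assignment array has 3 elements).
  sharding = xla_sharding.get_tensor_sharding(tiled)
  assert list(xla_sharding.get_sharding_tile_shape(sharding)) == [3]
  return tiled


result = tile_and_check(array_ops.ones([4, 5, 6], dtype=dtypes.float32))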