Adds get_sharding_tile_shape() function.

Adds many unit tests for xla_sharding.Sharding and helper functions, making use of the new get_sharding_tile_shape() function.

PiperOrigin-RevId: 361245107
Change-Id: I85ee889c9f2f59803cf59a5c2972fc3e75d3f037
A. Unique TensorFlower 2021-03-05 16:07:51 -08:00 committed by TensorFlower Gardener
parent 081dcedc8a
commit 128b50a574
2 changed files with 133 additions and 1 deletion

tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py

@@ -367,6 +367,28 @@ def get_tensor_sharding(tensor):
    return None


def get_sharding_tile_shape(sharding):
  """Returns the tile assignment shape for a sharded Tensor.

  Args:
    sharding: a serialized OpSharding message describing the layout of a
      sharded Tensor.

  Returns:
    A list, for each dimension of the sharded Tensor, of the number of shards
    into which it has been split. Returns None if the input indicates no tile
    assignments.
  """
  if sharding is None:
    return None
  sharding_message = xla_data_pb2.OpSharding()
  sharding_message.ParseFromString(sharding)
  if sharding_message.tile_assignment_dimensions:
    return sharding_message.tile_assignment_dimensions
  else:
    return None


def auto_to_manual_spmd_partition(tensor, manual_sharding):
  """Switches from automatic SPMD partitioning to manual partitioning.

tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding_test.py

@@ -17,12 +17,17 @@
from absl.testing import absltest
import numpy as np
from google.protobuf.message import DecodeError
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
-class ShardingTest(absltest.TestCase):
+class ShardingTest(test_util.TensorFlowTestCase):
"""Tests for member functions of the class xla_sharding.Sharding."""
def test_sharding_is_default_constructable(self):
sharding = xla_sharding.Sharding()
@@ -52,5 +57,110 @@ class ShardingTest(absltest.TestCase):
        xla_sharding.Sharding)


class XlaShardingTest(test_util.TensorFlowTestCase):
  """Tests for non-member functions in the module xla_sharding.py."""
  def test_replicate_annotates_tensor_correctly(self):

    @def_function.function
    def replicate_helper(tensor):
      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
      replicated_tensor = xla_sharding.replicate(tensor)
      replicated_sharding = xla_sharding.get_tensor_sharding(replicated_tensor)
      self.assertIsNotNone(replicated_sharding)
      # Replication has no tile assignment, so the tile shape is None.
      self.assertIsNone(
          xla_sharding.get_sharding_tile_shape(replicated_sharding))
      return replicated_tensor

    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
    result = replicate_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
    self.assertAllEqual(in_tensor, result)
  def test_tile_annotates_tensor_correctly(self):

    @def_function.function
    def tile_helper(tensor):
      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
      tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6]))
      self.assertIsInstance(tiled_tensor, ops.Tensor)
      tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor)
      tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding)
      # The tile assignment [2, 1, 6] is a rank-1 array of three elements,
      # so its shape is [3].
      expected_shape = [3]
      self.assertEqual(expected_shape, tile_shape)
      return tiled_tensor

    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
    result = tile_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
    self.assertAllEqual(in_tensor, result)
  def test_split_annotates_tensor_correctly(self):

    @def_function.function
    def split_helper(tensor):
      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
      split_tensor = xla_sharding.split(tensor, 2, 3)
      self.assertIsInstance(split_tensor, ops.Tensor)
      split_sharding = xla_sharding.get_tensor_sharding(split_tensor)
      split_shape = xla_sharding.get_sharding_tile_shape(split_sharding)
      expected_shape = [1, 1, 3]
      self.assertEqual(expected_shape, split_shape)
      return split_tensor

    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
    result = split_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
    self.assertAllEqual(in_tensor, result)
  def test_split_raises_error_with_incommensurate_dimensions(self):

    @def_function.function
    def split_helper(tensor):
      split_tensor = xla_sharding.split(tensor, 0, 8)
      return split_tensor

    with self.assertRaises(ValueError):
      _ = split_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
    # TODO(drm): Modify split() so that the following call also raises, since
    # 8 does not evenly divide 9. Currently split() only checks that the
    # number of shards does not exceed the dimension size, which is not
    # sufficient for an even split (see the sketch after the diff).
    # with self.assertRaises(ValueError):
    #   _ = split_helper(array_ops.ones([9, 5, 6], dtype=dtypes.float32))
  def test_copy_sharding_succeeds_with_identically_shaped_tensors(self):

    @def_function.function
    def copy_helper(tensor):
      tensor_src = array_ops.identity(tensor)
      tensor_src = xla_sharding.split(tensor_src, 2, 3)
      sharding_src = xla_sharding.get_tensor_sharding(tensor_src)
      shape_src = xla_sharding.get_sharding_tile_shape(sharding_src)
      self.assertEqual([1, 1, 3], shape_src)
      tensor_dest = array_ops.identity(tensor)
      self.assertIsNone(xla_sharding.get_tensor_sharding(tensor_dest))
      xla_sharding.copy_sharding(tensor_src, tensor_dest)
      sharding_dest = xla_sharding.get_tensor_sharding(tensor_dest)
      shape_dest = xla_sharding.get_sharding_tile_shape(sharding_dest)
      self.assertEqual([1, 1, 3], shape_dest)
      return tensor_dest

    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
    result = copy_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
    self.assertAllEqual(in_tensor, result)
  def test_get_sharding_tile_shape_returns_none_on_none_input(self):
    self.assertIsNone(xla_sharding.get_sharding_tile_shape(None))

  def test_get_sharding_tile_shape_raises_error_on_nonparsable_input(self):
    bad_proto_data = b'\x0f'
    with self.assertRaises(DecodeError):
      xla_sharding.get_sharding_tile_shape(bad_proto_data)


if __name__ == '__main__':
  absltest.main()
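
The TODO in test_split_raises_error_with_incommensurate_dimensions asks for split() to reject shard counts that do not divide the dimension evenly. A minimal sketch of such a check, assuming the tensor's static shape is available; check_split_divides_evenly is a hypothetical helper name, not part of xla_sharding:

    def check_split_divides_evenly(tensor_shape, split_dimension, num_devices):
      """Raises ValueError unless num_devices evenly divides the split dim.

      A hypothetical validation helper; xla_sharding.split() currently only
      checks that num_devices does not exceed the dimension size.
      """
      dim_size = tensor_shape[split_dimension]
      if dim_size % num_devices != 0:
        raise ValueError(
            'Dimension %d has size %d, which is not evenly divisible by '
            '%d devices.' % (split_dimension, dim_size, num_devices))


    # Eight devices cannot evenly split a dimension of size 9.
    try:
      check_split_divides_evenly([9, 5, 6], 0, 8)
    except ValueError as e:
      print(e)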