Adds get_sharding_tile_shape() function.
Adds many unit tests for xla_sharding.Sharding and helper functions, making use of the new get_sharding_tile_shape() function. PiperOrigin-RevId: 361245107 Change-Id: I85ee889c9f2f59803cf59a5c2972fc3e75d3f037
This commit is contained in:
parent
081dcedc8a
commit
128b50a574
@ -367,6 +367,28 @@ def get_tensor_sharding(tensor):
|
||||
return None
|
||||
|
||||
|
||||
def get_sharding_tile_shape(sharding):
|
||||
"""Returns the tile assignment shape for a sharded Tensor.
|
||||
|
||||
Args:
|
||||
sharding: a serialized OpSharding message describing the layout of a
|
||||
sharded Tensor.
|
||||
|
||||
Returns:
|
||||
A list, for each dimension of the sharded Tensor, of the number of shards
|
||||
into which it has been split. Returns None if the input indicates no tile
|
||||
assignments.
|
||||
"""
|
||||
if sharding is None:
|
||||
return None
|
||||
sharding_message = xla_data_pb2.OpSharding()
|
||||
sharding_message.ParseFromString(sharding)
|
||||
if sharding_message.tile_assignment_dimensions:
|
||||
return sharding_message.tile_assignment_dimensions
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def auto_to_manual_spmd_partition(tensor, manual_sharding):
|
||||
"""Switches from automatic SPMD partitioning to manual partitioning.
|
||||
|
||||
|
@ -17,12 +17,17 @@
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
from google.protobuf.message import DecodeError
|
||||
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
|
||||
class ShardingTest(absltest.TestCase):
|
||||
class ShardingTest(test_util.TensorFlowTestCase):
|
||||
"""Tests for member functions of the class xla_sharding.Sharding."""
|
||||
|
||||
def test_sharding_is_default_constructable(self):
|
||||
sharding = xla_sharding.Sharding()
|
||||
@ -52,5 +57,110 @@ class ShardingTest(absltest.TestCase):
|
||||
xla_sharding.Sharding)
|
||||
|
||||
|
||||
class XlaShardingTest(test_util.TensorFlowTestCase):
|
||||
"""Tests for non-member functions in the module xla_sharding.py."""
|
||||
|
||||
def test_replicate_annotates_tensor_correctly(self):
|
||||
|
||||
@def_function.function
|
||||
def replicate_helper(tensor):
|
||||
replicated_tensor = xla_sharding.replicate(
|
||||
array_ops.ones([4, 5, 6], dtype=dtypes.float32))
|
||||
self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
|
||||
replicated_sharding = xla_sharding.get_tensor_sharding(replicated_tensor)
|
||||
self.assertIsNotNone(replicated_sharding)
|
||||
self.assertIsNone(
|
||||
xla_sharding.get_sharding_tile_shape(replicated_sharding))
|
||||
return replicated_tensor
|
||||
|
||||
in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
|
||||
result = replicate_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
|
||||
self.assertAllEqual(in_tensor, result)
|
||||
|
||||
def test_tile_annotates_tensor_correctly(self):
|
||||
|
||||
@def_function.function
|
||||
def tile_helper(tensor):
|
||||
self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
|
||||
tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6]))
|
||||
self.assertIsInstance(tiled_tensor, ops.Tensor)
|
||||
tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor)
|
||||
tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding)
|
||||
# This is the shape of the tile assignment [2, 1, 6]
|
||||
expected_shape = [3]
|
||||
self.assertEqual(expected_shape, tile_shape)
|
||||
return tiled_tensor
|
||||
|
||||
in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
|
||||
result = tile_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
|
||||
self.assertAllEqual(in_tensor, result)
|
||||
|
||||
def test_split_annotates_tensor_correctly(self):
|
||||
|
||||
@def_function.function
|
||||
def split_helper(tensor):
|
||||
self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
|
||||
split_tensor = xla_sharding.split(tensor, 2, 3)
|
||||
self.assertIsInstance(split_tensor, ops.Tensor)
|
||||
split_sharding = xla_sharding.get_tensor_sharding(split_tensor)
|
||||
split_shape = xla_sharding.get_sharding_tile_shape(split_sharding)
|
||||
expected_shape = [1, 1, 3]
|
||||
self.assertEqual(expected_shape, split_shape)
|
||||
return split_tensor
|
||||
|
||||
in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
|
||||
result = split_helper(
|
||||
array_ops.ones([4, 5, 6], dtype=dtypes.float32))
|
||||
self.assertAllEqual(in_tensor, result)
|
||||
|
||||
def test_split_raises_error_with_incommensurate_dimensions(self):
|
||||
|
||||
@def_function.function
|
||||
def split_helper(tensor):
|
||||
split_tensor = xla_sharding.split(tensor, 0, 8)
|
||||
return split_tensor
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = split_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
|
||||
|
||||
# TODO(drm): Modify split() so that this call raises an error since
|
||||
# 8 does not divide 9 (currently only checks that 8 is smaller than 9,
|
||||
# which it is, but this is not good for splitting).
|
||||
# with self.assertRaises(ValueError):
|
||||
# _ = split_helper(array_ops.ones([9, 5, 6], dtype=dtypes.float32))
|
||||
|
||||
def test_copy_sharding_succeeds_with_identically_shaped_tensors(self):
|
||||
|
||||
@def_function.function
|
||||
def copy_helper(tensor):
|
||||
tensor_src = array_ops.identity(tensor)
|
||||
tensor_src = xla_sharding.split(tensor, 2, 3)
|
||||
sharding_src = xla_sharding.get_tensor_sharding(tensor_src)
|
||||
shape_src = xla_sharding.get_sharding_tile_shape(sharding_src)
|
||||
self.assertEqual([1, 1, 3], shape_src)
|
||||
|
||||
tensor_dest = array_ops.identity(tensor)
|
||||
self.assertIsNone(xla_sharding.get_tensor_sharding(tensor_dest))
|
||||
|
||||
xla_sharding.copy_sharding(tensor_src, tensor_dest)
|
||||
sharding_dest = xla_sharding.get_tensor_sharding(tensor_dest)
|
||||
shape_dest = xla_sharding.get_sharding_tile_shape(sharding_dest)
|
||||
self.assertEqual([1, 1, 3], shape_dest)
|
||||
return tensor_dest
|
||||
|
||||
in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
|
||||
result = copy_helper(
|
||||
array_ops.ones([4, 5, 6], dtype=dtypes.float32))
|
||||
self.assertAllEqual(in_tensor, result)
|
||||
|
||||
def test_get_sharding_tile_shape_returns_none_on_none_input(self):
|
||||
self.assertIsNone(xla_sharding.get_sharding_tile_shape(None))
|
||||
|
||||
def test_get_sharding_tile_shape_raises_error_on_nonparsable_input(self):
|
||||
bad_proto_data = b'\x0f'
|
||||
with self.assertRaises(DecodeError):
|
||||
xla_sharding.get_sharding_tile_shape(bad_proto_data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user