From e3be70aa9d472acaeeeea161bc609dd959decc9b Mon Sep 17 00:00:00 2001 From: Tomer Kaftan <kaftan@google.com> Date: Wed, 22 Jul 2020 14:30:00 -0700 Subject: [PATCH] Add a `convert_to_tensor` to the start of Tensor.__getitem__ (_slice_helper) to make sure it dispatches directly, rather than letting the nested tf.strided_slice trigger dispatching. This is important because `tensor.__getitem__` does some input arg manipulation before getting to the `tf.strided_slice`. So, when we try to run the traced code using the args provided to `strided_slice` (e.g. for KerasTensors), we lose information about constants that TPUs need to compile graphs involving shape manipulation. Tracing `__getitem__` and its input args directly does not seem to run into this problem. (Note: this TPU situation is separate from the shape value inferring we do in KerasTensors during Functional API construction/tracing time. This happens at model run-time when running the already-traced code) To get this all to work correctly in practice when dispatching KerasTensors + serializing/deserializing Keras models, this CL also has to: * Add special KerasTensor dispatchers for APIs that may take `slices` as inputs, to make sure they can trigger dispatching & serialize/deserialize correctly. This specialized dispatcher makes sure to unpack any `slices` in the args/kwargs into a namedtuple, before passing it to a specialized Keras TFOpLambda subclass that re-packs any slices. * Add serialization/deserialization support for `ellipsis` objects in Keras ------------------------ Other considered alternatives to get the dispatching/serialization to work correctly for KerasTensors: * add flatten/pack support for slices to `tf.nest`/`tree`. This can be revisited in the future (especially re: dispatchv2), but tree is critical path code and it's not obvious if we should always be flattening/packing slices or not. * Make the dispatched __operators__.getitem method expect slices to have already been unwrapped, and add a step to the __getitem__ overriding that unwraps the slices. This would be somewhat clunky in practice because there are other TF apis that take `slice`s in their args as well, and it might be surprising to dispatch users that the __operators__.getitem dispatch doesn't actually match the standard __getitem__ api. Likewise it's unclear what the performance implication of doing extra packing/unpacking even when not dispatching would be. PiperOrigin-RevId: 322655930 Change-Id: I35417577199393c016f753be685bf2926d62e753 --- tensorflow/python/keras/layers/core.py | 94 +++++++ .../keras/layers/tensorflow_op_layer_test.py | 233 ++++++++++++++++++ .../keras/saving/saved_model/json_utils.py | 6 + tensorflow/python/ops/array_ops.py | 2 + tensorflow/python/util/dispatch_test.py | 78 +++++- tensorflow/python/util/serialization.py | 3 + 6 files changed, 411 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 155af8d2398..2d69782a1cf 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy import functools import operator @@ -1422,3 +1423,96 @@ class KerasOpDispatcher(dispatch.GlobalOpDispatcher): return self.NOT_SUPPORTED KerasOpDispatcher().register() + +SliceTuple = collections.namedtuple('SliceTuple', ['start', 'stop', 'step']) + + +def _slice_to_named_tuple(x): + if isinstance(x, slice): + return SliceTuple(x.start, x.stop, x.step) + return x + + +def _named_tuple_to_slice(x): + if type(x).__name__ == 'SliceTuple': + return slice(x[0], x[1], x[2]) + return x + + +class SlicingOpLambda(TFOpLambda): + """Wraps TF API symbols in a `Layer` object. + + It is inserted by the Functional API construction whenever users call + a supported TF symbol on KerasTensors. + + Like Lambda layers, this layer tries to raise warnings when it detects users + explicitly use variables in the call. (To let them know + that the layer will not capture the variables). + + This is useful in the case where users do something like: + x = keras.Input(...) + y = tf.Variable(...) + out = x * tf_variable + """ + + @trackable.no_automatic_dependency_tracking + def __init__(self, function, **kwargs): + super(SlicingOpLambda, self).__init__(function, **kwargs) + + original_call = self.call + # Decorate the function to produce this layer's call method + def _call_wrapper(*args, **kwargs): + # Turn any slice nametuples in the args back into `slice` objects. + # This conversion cannot use nest.flatten/map_structure, + # because namedtuples are flattened by nest while slices aren't. + # So, map_structure would only see the individual elements in the + # namedtuple. + # This can't use map_structure_up_to either because the 'shallowness' of + # the shallow tree would have to vary depending on if only one dim or + # multiple are being sliced. + new_args = [] + for arg in args: + arg = _named_tuple_to_slice(arg) + if isinstance(arg, (list, tuple)): + new_arg = [] + for sub_arg in arg: + new_arg.append(_named_tuple_to_slice(sub_arg)) + arg = new_arg + new_args.append(arg) + + # Handle the kwargs too. + new_kwargs = {} + for key, value in kwargs.items(): + value = _named_tuple_to_slice(value) + if isinstance(value, (list, tuple)): + new_value = [] + for v in value: + new_value.append(_named_tuple_to_slice(v)) + value = new_value + new_kwargs[key] = value + + return original_call(*new_args, **new_kwargs) + self.call = tf_decorator.make_decorator(original_call, _call_wrapper) + + +class TFSlicingOpDispatcher(dispatch.OpDispatcher): + """A global dispatcher that allows building a functional model with TF Ops.""" + + def __init__(self, op): + self.op = op + + def handle(self, args, kwargs): + """Handle the specified operation with the specified arguments.""" + args = nest.map_structure(_slice_to_named_tuple, args) + kwargs = nest.map_structure(_slice_to_named_tuple, kwargs) + if any( + isinstance(x, keras_tensor.KerasTensor) + for x in nest.flatten([args, kwargs])): + return SlicingOpLambda(self.op)(*args, **kwargs) + else: + return self.NOT_SUPPORTED + +for slicing_op in [array_ops._slice_helper, # pylint: disable=protected-access + array_ops.boolean_mask, + array_ops.boolean_mask_v2]: + TFSlicingOpDispatcher(slicing_op).register(slicing_op) diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py index cb044260106..817e746bc70 100644 --- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py +++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import keras_tensor from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.keras.saving import model_config from tensorflow.python.ops import array_ops @@ -294,6 +295,238 @@ class AutoLambdaTest(keras_parameterized.TestCase): self.assertAllEqual([layer.name for layer in model.layers], [layer.name for layer in new_model.layers]) + def test_getitem_slice_with_step_only(self): + if not context.executing_eagerly(): + self.skipTest('Complex slicing like this fails in v1') + inp = keras.Input(shape=(4, 3, 8)) + slice_step = keras.Input(shape=(), dtype='int32') + + out = inp[..., ::slice_step[0]] + model = keras.Model( + inputs=[inp, slice_step], + outputs=out) + model.compile( + adam.Adam(0.001), + 'mse', + run_eagerly=testing_utils.should_run_eagerly()) + batch_size = 7 + step = 3 + x = array_ops.stack([ + math_ops.range(8) for _ in range(batch_size)]) + args = [x, constant_op.constant(step, shape=(batch_size,))] + expected = array_ops.stack([ + math_ops.range(8)[::step] for _ in range(batch_size)]) + + if keras_tensor.keras_tensors_enabled(): + self.assertIn('tf.__operators__.getitem', ( + x.name for x in model.layers)) + self.assertNotIn('tf.strided_slice', ( + x.name for x in model.layers)) + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + # Make sure it can be successfully saved and loaded + config = model.get_config() + model = keras.Model.from_config(config) + + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + def test_getitem_slice_real_tensor(self): + if not context.executing_eagerly(): + self.skipTest('Complex slicing like this fails in v1') + x = math_ops.range(10.0) + slice_stop = keras.Input(shape=(), dtype='int32') + + out = x[:slice_stop[0]] + model = keras.Model( + inputs=slice_stop, + outputs=out) + model.compile( + adam.Adam(0.001), + 'mse', + run_eagerly=testing_utils.should_run_eagerly()) + batch_size = 7 + stop = 6 + args = constant_op.constant(stop, shape=(batch_size,)) + expected = x[:stop] + + if keras_tensor.keras_tensors_enabled(): + self.assertIn('tf.__operators__.getitem', ( + x.name for x in model.layers)) + # TODO(b/161925288): Fix the dispatch triggering then uncomment: + # self.assertNotIn('tf.strided_slice', ( + # x.name for x in model.layers)) + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + # TODO(b/161925288): Fix the bug then uncomment: + # # Make sure it can be successfully saved and loaded + # config = model.get_config() + # model = keras.Model.from_config(config) + + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + def test_getitem_index_real_tensor(self): + if not context.executing_eagerly(): + self.skipTest('Complex slicing like this fails in v1') + x = math_ops.range(10.0) + slice_stop = keras.Input(shape=(), dtype='int32') + + out = x[slice_stop[0]] + model = keras.Model( + inputs=slice_stop, + outputs=out) + model.compile( + adam.Adam(0.001), + 'mse', + run_eagerly=testing_utils.should_run_eagerly()) + batch_size = 7 + index = 6 + args = constant_op.constant(index, shape=(batch_size,)) + expected = x[index] + + if keras_tensor.keras_tensors_enabled(): + self.assertIn('tf.__operators__.getitem', ( + x.name for x in model.layers)) + # TODO(b/161925288): Fix the bug then uncomment: + # self.assertNotIn('tf.strided_slice', ( + # x.name for x in model.layers)) + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + # TODO(b/161925288): Fix the bug then uncomment: + # # Make sure it can be successfully saved and loaded + # config = model.get_config() + # model = keras.Model.from_config(config) + + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + def test_getitem_slice_with_stop_only(self): + if not context.executing_eagerly(): + self.skipTest('Complex slicing like this fails in v1') + inp = keras.Input(shape=(4, 3, 8)) + slice_stop = keras.Input(shape=(), dtype='int32') + + out = inp[:slice_stop[0]] + model = keras.Model( + inputs=[inp, slice_stop], + outputs=out) + model.compile( + adam.Adam(0.001), + 'mse', + run_eagerly=testing_utils.should_run_eagerly()) + batch_size = 7 + stop = 6 + x = array_ops.stack([ + math_ops.range(8) for _ in range(batch_size)]) + args = [x, constant_op.constant(stop, shape=(batch_size,))] + expected = x[:stop] + + if keras_tensor.keras_tensors_enabled(): + self.assertIn('tf.__operators__.getitem', ( + x.name for x in model.layers)) + self.assertNotIn('tf.strided_slice', ( + x.name for x in model.layers)) + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + # Make sure it can be successfully saved and loaded + config = model.get_config() + model = keras.Model.from_config(config) + + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + def test_getitem_slice_with_stop_and_ellipsis_only(self): + if not context.executing_eagerly(): + self.skipTest('Complex slicing like this fails in v1') + inp = keras.Input(shape=(4, 3, 8)) + slice_stop = keras.Input(shape=(), dtype='int32') + + out = inp[..., :slice_stop[0]] + model = keras.Model( + inputs=[inp, slice_stop], + outputs=out) + model.compile( + adam.Adam(0.001), + 'mse', + run_eagerly=testing_utils.should_run_eagerly()) + batch_size = 7 + stop = 6 + x = array_ops.stack([ + math_ops.range(8) for _ in range(batch_size)]) + args = [x, constant_op.constant(stop, shape=(batch_size,))] + expected = array_ops.stack([ + math_ops.range(8)[:stop] for _ in range(batch_size)]) + + if keras_tensor.keras_tensors_enabled(): + self.assertIn('tf.__operators__.getitem', ( + x.name for x in model.layers)) + self.assertNotIn('tf.strided_slice', ( + x.name for x in model.layers)) + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + # Make sure it can be successfully saved and loaded + config = model.get_config() + model = keras.Model.from_config(config) + + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + def test_getitem_complex_slicing(self): + if not context.executing_eagerly(): + self.skipTest('Complex slicing like this fails in v1') + inp = keras.Input(shape=(4, 3, 8)) + first_dim = keras.Input(shape=(), dtype='int32') + slice_start = keras.Input(shape=(), dtype='int32') + slice_stop = keras.Input(shape=(), dtype='int32') + slice_stride = keras.Input(shape=(), dtype='int32') + + out = inp[..., first_dim[0], slice_start[0]:slice_stop[0]:slice_stride[0]] + model = keras.Model( + inputs=[inp, first_dim, slice_start, slice_stop, slice_stride], + outputs=out) + model.compile( + adam.Adam(0.001), + 'mse', + run_eagerly=testing_utils.should_run_eagerly()) + batch_size = 7 + start = 1 + stop = 6 + step = 2 + x = array_ops.stack([array_ops.stack([array_ops.stack([ + math_ops.range(8) + for _ in range(3)]) for _ in range(4)]) for _ in range(batch_size)]) + args = [x, + constant_op.constant(0, shape=(batch_size,)), + constant_op.constant(start, shape=(batch_size,)), + constant_op.constant(stop, shape=(batch_size,)), + constant_op.constant(step, shape=(batch_size,))] + # Slice the innermost dim. only grab one index from the second-to-innermost + # dim, removing that dim from the shape. + expected = array_ops.stack([array_ops.stack([ + math_ops.range(8)[start:stop:step] + for _ in range(4)]) for _ in range(batch_size)]) + + if keras_tensor.keras_tensors_enabled(): + self.assertIn('tf.__operators__.getitem', ( + x.name for x in model.layers)) + self.assertNotIn('tf.strided_slice', ( + x.name for x in model.layers)) + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + + # Make sure it can be successfully saved and loaded + config = model.get_config() + model = keras.Model.from_config(config) + + self.assertAllEqual(model(args), expected) + self.assertAllEqual(model.predict(args, batch_size=batch_size), expected) + def test_numerical_correctness_simple(self): x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]]) inputs = keras.Input(shape=(4,)) diff --git a/tensorflow/python/keras/saving/saved_model/json_utils.py b/tensorflow/python/keras/saving/saved_model/json_utils.py index 4e4b671697a..d06e4180564 100644 --- a/tensorflow/python/keras/saving/saved_model/json_utils.py +++ b/tensorflow/python/keras/saving/saved_model/json_utils.py @@ -70,11 +70,14 @@ def decode(json_string): def _decode_helper(obj): + """A decoding helper that is TF-object aware.""" if isinstance(obj, dict) and 'class_name' in obj: if obj['class_name'] == 'TensorShape': return tensor_shape.TensorShape(obj['items']) elif obj['class_name'] == '__tuple__': return tuple(_decode_helper(i) for i in obj['items']) + elif obj['class_name'] == '__ellipsis__': + return Ellipsis return obj @@ -122,6 +125,9 @@ def get_json_type(obj): if isinstance(obj, collections_abc.Mapping): return dict(obj) + if obj is Ellipsis: + return {'class_name': '__ellipsis__'} + if isinstance(obj, wrapt.ObjectProxy): return obj.__wrapped__ diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 8e9bc1ef4d3..e9f32dec6b8 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -955,6 +955,8 @@ def _slice_helper(tensor, slice_spec, var=None): TypeError: If the slice indices aren't int, slice, ellipsis, tf.newaxis or scalar int32/int64 tensors. """ + tensor = ops.convert_to_tensor(tensor) + if isinstance(slice_spec, bool) or \ (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \ (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool): diff --git a/tensorflow/python/util/dispatch_test.py b/tensorflow/python/util/dispatch_test.py index cc4fed0abb7..2b3946ce9f7 100644 --- a/tensorflow/python/util/dispatch_test.py +++ b/tensorflow/python/util/dispatch_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.proto_ops import decode_proto @@ -28,6 +29,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging from tensorflow.python.util import deprecation from tensorflow.python.util import dispatch +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import get_canonical_name_for_symbol from tensorflow.python.util.tf_export import tf_export @@ -68,10 +70,38 @@ class TensorTracer(object): ["{}={}".format(name, x) for (name, x) in self.kwargs.items()]) return "{}({})".format(self.name, ", ".join(args)) + @classmethod + def _overload_all_operators(cls): # pylint: disable=invalid-name + """Register overloads for all operators.""" + for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + cls._overload_operator(operator) + + @classmethod + def _overload_operator(cls, operator): # pylint: disable=invalid-name + """Overload an operator with the same overloading as `ops.Tensor`.""" + tensor_oper = getattr(ops.Tensor, operator) + + # Compatibility with Python 2: + # Python 2 unbound methods have type checks for the first arg, + # so we need to extract the underlying function + tensor_oper = getattr(tensor_oper, "__func__", tensor_oper) + setattr(cls, operator, tensor_oper) + +TensorTracer._overload_all_operators() # pylint: disable=protected-access + class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher): """Global op dispatcher for TensorTracer.""" + def _flatten_with_slice_flattening(self, x): + flat = [] + for val in nest.flatten(x): + if isinstance(val, slice): + flat.extend((val.start, val.stop, val.step)) + else: + flat.append(val) + return flat + def handle(self, op, args, kwargs): # Dispatcher only applies if at least one arg is a TensorTracer. if not (any(self.is_tensor_tracer_arg(x) for x in args) or @@ -82,11 +112,8 @@ class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher): return TensorTracer(symbol_name, args, kwargs) def is_tensor_tracer_arg(self, value): - if isinstance(value, TensorTracer): - return True - if isinstance(value, (list, tuple)): - if any(isinstance(x, TensorTracer) for x in value): - return True + return any(isinstance(x, TensorTracer) for x in + self._flatten_with_slice_flattening(value)) @test_util.run_all_in_graph_and_eager_modes @@ -214,5 +241,46 @@ class DispatchTest(test_util.TensorFlowTestCase): # Clean up. dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers + def testGlobalDispatcherGetItem(self): + original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS + try: + TensorTracerOpDispatcher().register() + + x = TensorTracer("x") + trace = x[0] + self.assertEqual( + str(trace), + "__operators__.getitem(x, 0)") + + x = TensorTracer("x") + y = TensorTracer("y") + trace = x[y] + self.assertEqual( + str(trace), + "__operators__.getitem(x, y)") + + x = TensorTracer("x") + y = TensorTracer("y") + trace = x[:y] # pylint: disable=invalid-slice-index + self.assertEqual( + str(trace), + "__operators__.getitem(x, slice(None, y, None))") + + x = array_ops.ones(shape=(3, 3)) + y = TensorTracer("y") + trace = x[y] + self.assertEqual( + str(trace), + "__operators__.getitem(%s, y)" % x) + + trace = x[:y] # pylint: disable=invalid-slice-index + self.assertEqual( + str(trace), + "__operators__.getitem(%s, slice(None, y, None))" % x) + + finally: + # Clean up. + dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/util/serialization.py b/tensorflow/python/util/serialization.py index 3b1713b4c61..e35d5ff5d5d 100644 --- a/tensorflow/python/util/serialization.py +++ b/tensorflow/python/util/serialization.py @@ -70,6 +70,9 @@ def get_json_type(obj): if isinstance(obj, collections_abc.Mapping): return dict(obj) + if obj is Ellipsis: + return {'class_name': '__ellipsis__'} + if isinstance(obj, wrapt.ObjectProxy): return obj.__wrapped__