From e3be70aa9d472acaeeeea161bc609dd959decc9b Mon Sep 17 00:00:00 2001
From: Tomer Kaftan <kaftan@google.com>
Date: Wed, 22 Jul 2020 14:30:00 -0700
Subject: [PATCH] Add a `convert_to_tensor` to the start of Tensor.__getitem__
 (_slice_helper) to make sure it dispatches directly, rather than letting the
 nested tf.strided_slice trigger dispatching.

This is important because `tensor.__getitem__` does some input arg manipulation before getting to the `tf.strided_slice`. So, when we try to run the traced code using the args provided to `strided_slice` (e.g. for KerasTensors), we lose information about constants that TPUs need to compile graphs involving shape manipulation. Tracing `__getitem__` and its input args directly does not seem to run into this problem.

(Note: this TPU situation is separate from the shape value inferring we do in KerasTensors during Functional API construction/tracing time. This happens at model run-time when running the already-traced code)

To get this all to work correctly in practice when dispatching KerasTensors + serializing/deserializing Keras models, this CL also has to:

* Add special KerasTensor dispatchers for APIs that may take `slices` as inputs, to make sure they can trigger dispatching & serialize/deserialize correctly. This specialized dispatcher makes sure to unpack any `slices` in the args/kwargs into a namedtuple, before passing it to a specialized Keras TFOpLambda subclass that re-packs any slices.

* Add serialization/deserialization support for `ellipsis` objects in Keras

------------------------
Other considered alternatives to get the dispatching/serialization to work correctly for KerasTensors:

* add flatten/pack support for slices to `tf.nest`/`tree`. This can be revisited in the future (especially re: dispatchv2), but tree is critical path code and it's not obvious if we should always be flattening/packing slices or not.

* Make the dispatched __operators__.getitem method expect slices to have already been unwrapped, and add a step to the __getitem__ overriding that unwraps the slices. This would be somewhat clunky in practice because there are other TF apis that take `slice`s in their args as well, and it might be surprising to dispatch users that the __operators__.getitem dispatch doesn't actually match the standard __getitem__ api. Likewise it's unclear what the performance implication of doing extra packing/unpacking even when not dispatching would be.

PiperOrigin-RevId: 322655930
Change-Id: I35417577199393c016f753be685bf2926d62e753
---
 tensorflow/python/keras/layers/core.py        |  94 +++++++
 .../keras/layers/tensorflow_op_layer_test.py  | 233 ++++++++++++++++++
 .../keras/saving/saved_model/json_utils.py    |   6 +
 tensorflow/python/ops/array_ops.py            |   2 +
 tensorflow/python/util/dispatch_test.py       |  78 +++++-
 tensorflow/python/util/serialization.py       |   3 +
 6 files changed, 411 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 155af8d2398..2d69782a1cf 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import copy
 import functools
 import operator
@@ -1422,3 +1423,96 @@ class KerasOpDispatcher(dispatch.GlobalOpDispatcher):
       return self.NOT_SUPPORTED
 
 KerasOpDispatcher().register()
+
+SliceTuple = collections.namedtuple('SliceTuple', ['start', 'stop', 'step'])
+
+
+def _slice_to_named_tuple(x):
+  if isinstance(x, slice):
+    return SliceTuple(x.start, x.stop, x.step)
+  return x
+
+
+def _named_tuple_to_slice(x):
+  if type(x).__name__ == 'SliceTuple':
+    return slice(x[0], x[1], x[2])
+  return x
+
+
+class SlicingOpLambda(TFOpLambda):
+  """Wraps TF API symbols in a `Layer` object.
+
+  It is inserted by the Functional API construction whenever users call
+  a supported TF symbol on KerasTensors.
+
+  Like Lambda layers, this layer tries to raise warnings when it detects users
+  explicitly use variables in the call. (To let them know
+  that the layer will not capture the variables).
+
+  This is useful in the case where users do something like:
+  x = keras.Input(...)
+  y = tf.Variable(...)
+  out = x * tf_variable
+  """
+
+  @trackable.no_automatic_dependency_tracking
+  def __init__(self, function, **kwargs):
+    super(SlicingOpLambda, self).__init__(function, **kwargs)
+
+    original_call = self.call
+    # Decorate the function to produce this layer's call method
+    def _call_wrapper(*args, **kwargs):
+      # Turn any slice nametuples in the args back into `slice` objects.
+      # This conversion cannot use nest.flatten/map_structure,
+      # because namedtuples are flattened by nest while slices aren't.
+      # So, map_structure would only see the individual elements in the
+      # namedtuple.
+      # This can't use map_structure_up_to either because the 'shallowness' of
+      # the shallow tree would have to vary depending on if only one dim or
+      # multiple are being sliced.
+      new_args = []
+      for arg in args:
+        arg = _named_tuple_to_slice(arg)
+        if isinstance(arg, (list, tuple)):
+          new_arg = []
+          for sub_arg in arg:
+            new_arg.append(_named_tuple_to_slice(sub_arg))
+          arg = new_arg
+        new_args.append(arg)
+
+      # Handle the kwargs too.
+      new_kwargs = {}
+      for key, value in kwargs.items():
+        value = _named_tuple_to_slice(value)
+        if isinstance(value, (list, tuple)):
+          new_value = []
+          for v in value:
+            new_value.append(_named_tuple_to_slice(v))
+          value = new_value
+        new_kwargs[key] = value
+
+      return original_call(*new_args, **new_kwargs)
+    self.call = tf_decorator.make_decorator(original_call, _call_wrapper)
+
+
+class TFSlicingOpDispatcher(dispatch.OpDispatcher):
+  """A global dispatcher that allows building a functional model with TF Ops."""
+
+  def __init__(self, op):
+    self.op = op
+
+  def handle(self, args, kwargs):
+    """Handle the specified operation with the specified arguments."""
+    args = nest.map_structure(_slice_to_named_tuple, args)
+    kwargs = nest.map_structure(_slice_to_named_tuple, kwargs)
+    if any(
+        isinstance(x, keras_tensor.KerasTensor)
+        for x in nest.flatten([args, kwargs])):
+      return SlicingOpLambda(self.op)(*args, **kwargs)
+    else:
+      return self.NOT_SUPPORTED
+
+for slicing_op in [array_ops._slice_helper,  # pylint: disable=protected-access
+                   array_ops.boolean_mask,
+                   array_ops.boolean_mask_v2]:
+  TFSlicingOpDispatcher(slicing_op).register(slicing_op)
diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
index cb044260106..817e746bc70 100644
--- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
+++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.engine import keras_tensor
 from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.keras.saving import model_config
 from tensorflow.python.ops import array_ops
@@ -294,6 +295,238 @@ class AutoLambdaTest(keras_parameterized.TestCase):
     self.assertAllEqual([layer.name for layer in model.layers],
                         [layer.name for layer in new_model.layers])
 
+  def test_getitem_slice_with_step_only(self):
+    if not context.executing_eagerly():
+      self.skipTest('Complex slicing like this fails in v1')
+    inp = keras.Input(shape=(4, 3, 8))
+    slice_step = keras.Input(shape=(), dtype='int32')
+
+    out = inp[..., ::slice_step[0]]
+    model = keras.Model(
+        inputs=[inp, slice_step],
+        outputs=out)
+    model.compile(
+        adam.Adam(0.001),
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    batch_size = 7
+    step = 3
+    x = array_ops.stack([
+        math_ops.range(8) for _ in range(batch_size)])
+    args = [x, constant_op.constant(step, shape=(batch_size,))]
+    expected = array_ops.stack([
+        math_ops.range(8)[::step] for _ in range(batch_size)])
+
+    if keras_tensor.keras_tensors_enabled():
+      self.assertIn('tf.__operators__.getitem', (
+          x.name for x in model.layers))
+      self.assertNotIn('tf.strided_slice', (
+          x.name for x in model.layers))
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+    # Make sure it can be successfully saved and loaded
+    config = model.get_config()
+    model = keras.Model.from_config(config)
+
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+  def test_getitem_slice_real_tensor(self):
+    if not context.executing_eagerly():
+      self.skipTest('Complex slicing like this fails in v1')
+    x = math_ops.range(10.0)
+    slice_stop = keras.Input(shape=(), dtype='int32')
+
+    out = x[:slice_stop[0]]
+    model = keras.Model(
+        inputs=slice_stop,
+        outputs=out)
+    model.compile(
+        adam.Adam(0.001),
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    batch_size = 7
+    stop = 6
+    args = constant_op.constant(stop, shape=(batch_size,))
+    expected = x[:stop]
+
+    if keras_tensor.keras_tensors_enabled():
+      self.assertIn('tf.__operators__.getitem', (
+          x.name for x in model.layers))
+      # TODO(b/161925288): Fix the dispatch triggering then uncomment:
+      # self.assertNotIn('tf.strided_slice', (
+      #     x.name for x in model.layers))
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+    # TODO(b/161925288): Fix the bug then uncomment:
+    # # Make sure it can be successfully saved and loaded
+    # config = model.get_config()
+    # model = keras.Model.from_config(config)
+
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+  def test_getitem_index_real_tensor(self):
+    if not context.executing_eagerly():
+      self.skipTest('Complex slicing like this fails in v1')
+    x = math_ops.range(10.0)
+    slice_stop = keras.Input(shape=(), dtype='int32')
+
+    out = x[slice_stop[0]]
+    model = keras.Model(
+        inputs=slice_stop,
+        outputs=out)
+    model.compile(
+        adam.Adam(0.001),
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    batch_size = 7
+    index = 6
+    args = constant_op.constant(index, shape=(batch_size,))
+    expected = x[index]
+
+    if keras_tensor.keras_tensors_enabled():
+      self.assertIn('tf.__operators__.getitem', (
+          x.name for x in model.layers))
+      # TODO(b/161925288): Fix the bug then uncomment:
+      # self.assertNotIn('tf.strided_slice', (
+      #     x.name for x in model.layers))
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+    # TODO(b/161925288): Fix the bug then uncomment:
+    # # Make sure it can be successfully saved and loaded
+    # config = model.get_config()
+    # model = keras.Model.from_config(config)
+
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+  def test_getitem_slice_with_stop_only(self):
+    if not context.executing_eagerly():
+      self.skipTest('Complex slicing like this fails in v1')
+    inp = keras.Input(shape=(4, 3, 8))
+    slice_stop = keras.Input(shape=(), dtype='int32')
+
+    out = inp[:slice_stop[0]]
+    model = keras.Model(
+        inputs=[inp, slice_stop],
+        outputs=out)
+    model.compile(
+        adam.Adam(0.001),
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    batch_size = 7
+    stop = 6
+    x = array_ops.stack([
+        math_ops.range(8) for _ in range(batch_size)])
+    args = [x, constant_op.constant(stop, shape=(batch_size,))]
+    expected = x[:stop]
+
+    if keras_tensor.keras_tensors_enabled():
+      self.assertIn('tf.__operators__.getitem', (
+          x.name for x in model.layers))
+      self.assertNotIn('tf.strided_slice', (
+          x.name for x in model.layers))
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+    # Make sure it can be successfully saved and loaded
+    config = model.get_config()
+    model = keras.Model.from_config(config)
+
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+  def test_getitem_slice_with_stop_and_ellipsis_only(self):
+    if not context.executing_eagerly():
+      self.skipTest('Complex slicing like this fails in v1')
+    inp = keras.Input(shape=(4, 3, 8))
+    slice_stop = keras.Input(shape=(), dtype='int32')
+
+    out = inp[..., :slice_stop[0]]
+    model = keras.Model(
+        inputs=[inp, slice_stop],
+        outputs=out)
+    model.compile(
+        adam.Adam(0.001),
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    batch_size = 7
+    stop = 6
+    x = array_ops.stack([
+        math_ops.range(8) for _ in range(batch_size)])
+    args = [x, constant_op.constant(stop, shape=(batch_size,))]
+    expected = array_ops.stack([
+        math_ops.range(8)[:stop] for _ in range(batch_size)])
+
+    if keras_tensor.keras_tensors_enabled():
+      self.assertIn('tf.__operators__.getitem', (
+          x.name for x in model.layers))
+      self.assertNotIn('tf.strided_slice', (
+          x.name for x in model.layers))
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+    # Make sure it can be successfully saved and loaded
+    config = model.get_config()
+    model = keras.Model.from_config(config)
+
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+  def test_getitem_complex_slicing(self):
+    if not context.executing_eagerly():
+      self.skipTest('Complex slicing like this fails in v1')
+    inp = keras.Input(shape=(4, 3, 8))
+    first_dim = keras.Input(shape=(), dtype='int32')
+    slice_start = keras.Input(shape=(), dtype='int32')
+    slice_stop = keras.Input(shape=(), dtype='int32')
+    slice_stride = keras.Input(shape=(), dtype='int32')
+
+    out = inp[..., first_dim[0], slice_start[0]:slice_stop[0]:slice_stride[0]]
+    model = keras.Model(
+        inputs=[inp, first_dim, slice_start, slice_stop, slice_stride],
+        outputs=out)
+    model.compile(
+        adam.Adam(0.001),
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    batch_size = 7
+    start = 1
+    stop = 6
+    step = 2
+    x = array_ops.stack([array_ops.stack([array_ops.stack([
+        math_ops.range(8)
+        for _ in range(3)]) for _ in range(4)]) for _ in range(batch_size)])
+    args = [x,
+            constant_op.constant(0, shape=(batch_size,)),
+            constant_op.constant(start, shape=(batch_size,)),
+            constant_op.constant(stop, shape=(batch_size,)),
+            constant_op.constant(step, shape=(batch_size,))]
+    # Slice the innermost dim. only grab one index from the second-to-innermost
+    # dim, removing that dim from the shape.
+    expected = array_ops.stack([array_ops.stack([
+        math_ops.range(8)[start:stop:step]
+        for _ in range(4)]) for _ in range(batch_size)])
+
+    if keras_tensor.keras_tensors_enabled():
+      self.assertIn('tf.__operators__.getitem', (
+          x.name for x in model.layers))
+      self.assertNotIn('tf.strided_slice', (
+          x.name for x in model.layers))
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
+    # Make sure it can be successfully saved and loaded
+    config = model.get_config()
+    model = keras.Model.from_config(config)
+
+    self.assertAllEqual(model(args), expected)
+    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
+
   def test_numerical_correctness_simple(self):
     x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]])
     inputs = keras.Input(shape=(4,))
diff --git a/tensorflow/python/keras/saving/saved_model/json_utils.py b/tensorflow/python/keras/saving/saved_model/json_utils.py
index 4e4b671697a..d06e4180564 100644
--- a/tensorflow/python/keras/saving/saved_model/json_utils.py
+++ b/tensorflow/python/keras/saving/saved_model/json_utils.py
@@ -70,11 +70,14 @@ def decode(json_string):
 
 
 def _decode_helper(obj):
+  """A decoding helper that is TF-object aware."""
   if isinstance(obj, dict) and 'class_name' in obj:
     if obj['class_name'] == 'TensorShape':
       return tensor_shape.TensorShape(obj['items'])
     elif obj['class_name'] == '__tuple__':
       return tuple(_decode_helper(i) for i in obj['items'])
+    elif obj['class_name'] == '__ellipsis__':
+      return Ellipsis
   return obj
 
 
@@ -122,6 +125,9 @@ def get_json_type(obj):
   if isinstance(obj, collections_abc.Mapping):
     return dict(obj)
 
+  if obj is Ellipsis:
+    return {'class_name': '__ellipsis__'}
+
   if isinstance(obj, wrapt.ObjectProxy):
     return obj.__wrapped__
 
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 8e9bc1ef4d3..e9f32dec6b8 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -955,6 +955,8 @@ def _slice_helper(tensor, slice_spec, var=None):
     TypeError: If the slice indices aren't int, slice, ellipsis,
       tf.newaxis or scalar int32/int64 tensors.
   """
+  tensor = ops.convert_to_tensor(tensor)
+
   if isinstance(slice_spec, bool) or \
   (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
   (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
diff --git a/tensorflow/python/util/dispatch_test.py b/tensorflow/python/util/dispatch_test.py
index cc4fed0abb7..2b3946ce9f7 100644
--- a/tensorflow/python/util/dispatch_test.py
+++ b/tensorflow/python/util/dispatch_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.proto_ops import decode_proto
@@ -28,6 +29,7 @@ from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import dispatch
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
 from tensorflow.python.util.tf_export import tf_export
 
@@ -68,10 +70,38 @@ class TensorTracer(object):
           ["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
       return "{}({})".format(self.name, ", ".join(args))
 
+  @classmethod
+  def _overload_all_operators(cls):  # pylint: disable=invalid-name
+    """Register overloads for all operators."""
+    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
+      cls._overload_operator(operator)
+
+  @classmethod
+  def _overload_operator(cls, operator):  # pylint: disable=invalid-name
+    """Overload an operator with the same overloading as `ops.Tensor`."""
+    tensor_oper = getattr(ops.Tensor, operator)
+
+    # Compatibility with Python 2:
+    # Python 2 unbound methods have type checks for the first arg,
+    # so we need to extract the underlying function
+    tensor_oper = getattr(tensor_oper, "__func__", tensor_oper)
+    setattr(cls, operator, tensor_oper)
+
+TensorTracer._overload_all_operators()  # pylint: disable=protected-access
+
 
 class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
   """Global op dispatcher for TensorTracer."""
 
+  def _flatten_with_slice_flattening(self, x):
+    flat = []
+    for val in nest.flatten(x):
+      if isinstance(val, slice):
+        flat.extend((val.start, val.stop, val.step))
+      else:
+        flat.append(val)
+    return flat
+
   def handle(self, op, args, kwargs):
     # Dispatcher only applies if at least one arg is a TensorTracer.
     if not (any(self.is_tensor_tracer_arg(x) for x in args) or
@@ -82,11 +112,8 @@ class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
     return TensorTracer(symbol_name, args, kwargs)
 
   def is_tensor_tracer_arg(self, value):
-    if isinstance(value, TensorTracer):
-      return True
-    if isinstance(value, (list, tuple)):
-      if any(isinstance(x, TensorTracer) for x in value):
-        return True
+    return any(isinstance(x, TensorTracer) for x in
+               self._flatten_with_slice_flattening(value))
 
 
 @test_util.run_all_in_graph_and_eager_modes
@@ -214,5 +241,46 @@ class DispatchTest(test_util.TensorFlowTestCase):
       # Clean up.
       dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
 
+  def testGlobalDispatcherGetItem(self):
+    original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
+    try:
+      TensorTracerOpDispatcher().register()
+
+      x = TensorTracer("x")
+      trace = x[0]
+      self.assertEqual(
+          str(trace),
+          "__operators__.getitem(x, 0)")
+
+      x = TensorTracer("x")
+      y = TensorTracer("y")
+      trace = x[y]
+      self.assertEqual(
+          str(trace),
+          "__operators__.getitem(x, y)")
+
+      x = TensorTracer("x")
+      y = TensorTracer("y")
+      trace = x[:y]  # pylint: disable=invalid-slice-index
+      self.assertEqual(
+          str(trace),
+          "__operators__.getitem(x, slice(None, y, None))")
+
+      x = array_ops.ones(shape=(3, 3))
+      y = TensorTracer("y")
+      trace = x[y]
+      self.assertEqual(
+          str(trace),
+          "__operators__.getitem(%s, y)" % x)
+
+      trace = x[:y]  # pylint: disable=invalid-slice-index
+      self.assertEqual(
+          str(trace),
+          "__operators__.getitem(%s, slice(None, y, None))" % x)
+
+    finally:
+      # Clean up.
+      dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
+
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/util/serialization.py b/tensorflow/python/util/serialization.py
index 3b1713b4c61..e35d5ff5d5d 100644
--- a/tensorflow/python/util/serialization.py
+++ b/tensorflow/python/util/serialization.py
@@ -70,6 +70,9 @@ def get_json_type(obj):
   if isinstance(obj, collections_abc.Mapping):
     return dict(obj)
 
+  if obj is Ellipsis:
+    return {'class_name': '__ellipsis__'}
+
   if isinstance(obj, wrapt.ObjectProxy):
     return obj.__wrapped__