Add a `convert_to_tensor` call at the start of `Tensor.__getitem__` (`_slice_helper`) to make sure it dispatches directly, rather than letting the nested `tf.strided_slice` trigger dispatching.

This is important because `tensor.__getitem__` does some input-arg manipulation before reaching `tf.strided_slice`. So, when we try to run the traced code using the args provided to `strided_slice` (e.g. for KerasTensors), we lose information about constants that TPUs need to compile graphs involving shape manipulation. Tracing `__getitem__` and its input args directly does not run into this problem.
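
As a hedged illustration of the difference (the dispatcher class below is hypothetical; the registration pattern matches the one this CL adds): a dispatcher registered directly on `_slice_helper` receives the original `slice_spec` -- Python ints, `slice` objects, Ellipsis -- rather than the begin/end/strides tensors that a dispatcher on the nested `tf.strided_slice` would see.

from tensorflow.python.ops import array_ops
from tensorflow.python.util import dispatch

class LoggingSliceDispatcher(dispatch.OpDispatcher):
  """Hypothetical dispatcher that records the raw __getitem__ args."""

  def handle(self, args, kwargs):
    # args is (tensor, slice_spec, ...); slice_spec still holds the original
    # Python ints, `slice` objects, and Ellipsis.
    print('getitem slice_spec:', args[1])
    return self.NOT_SUPPORTED  # fall through to the normal implementation

# Note: dispatch only fires when _slice_helper raises TypeError on an
# unsupported type (e.g. a KerasTensor); ordinary tensors never reach this.
LoggingSliceDispatcher().register(array_ops._slice_helper)  # pylint: disable=protected-access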

(Note: this TPU situation is separate from the shape value inference we do on KerasTensors during Functional API construction/tracing time. It happens at model run-time, when running the already-traced code.)

To get this all to work correctly in practice when dispatching KerasTensors and serializing/deserializing Keras models, this CL also has to:

* Add special KerasTensor dispatchers for APIs that may take `slice`s as inputs, to make sure they can trigger dispatching and serialize/deserialize correctly. This specialized dispatcher unpacks any `slice`s in the args/kwargs into namedtuples before passing them to a specialized Keras TFOpLambda subclass that re-packs them into `slice`s. (A sketch of this round trip follows the list.)

* Add serialization/deserialization support for `ellipsis` objects in Keras.
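
A minimal sketch of the pack/unpack round trip described above (the helper names match the ones this CL adds to the Keras layer code; `tf.nest` flattens namedtuples but treats `slice` objects as opaque leaves, which is what makes the conversion necessary):

import collections
from tensorflow.python.util import nest

SliceTuple = collections.namedtuple('SliceTuple', ['start', 'stop', 'step'])

def _slice_to_named_tuple(x):
  # Expose the slice components to nest by re-packing them as a namedtuple.
  return SliceTuple(x.start, x.stop, x.step) if isinstance(x, slice) else x

def _named_tuple_to_slice(x):
  # Reverse direction: rebuild the `slice` before calling the wrapped op.
  return slice(*x) if type(x).__name__ == 'SliceTuple' else x

spec = (Ellipsis, slice(1, 6, 2))
packed = nest.map_structure(_slice_to_named_tuple, spec)
print(nest.flatten(packed))  # [Ellipsis, 1, 6, 2]: components now visible
assert tuple(_named_tuple_to_slice(x) for x in packed) == spec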

------------------------
Other alternatives considered for getting the dispatching/serialization to work correctly for KerasTensors:

* Add flatten/pack support for slices to `tf.nest`/`tree`. This can be revisited in the future (especially re: dispatch v2), but `tree` is critical-path code, and it's not obvious whether slices should always be flattened/packed.

* Make the dispatched `__operators__.getitem` method expect slices to have already been unwrapped, and add a step to the `__getitem__` override that unwraps the slices. This would be somewhat clunky in practice because other TF APIs take `slice`s in their args as well, and it might surprise dispatch users that the `__operators__.getitem` dispatch doesn't actually match the standard `__getitem__` API. Likewise, it's unclear what the performance implications of the extra packing/unpacking would be even when not dispatching.

PiperOrigin-RevId: 322655930
Change-Id: I35417577199393c016f753be685bf2926d62e753
Tomer Kaftan 2020-07-22 14:30:00 -07:00, committed by TensorFlower Gardener
commit e3be70aa9d, parent e2f7c83bcb
6 changed files with 411 additions and 5 deletions

File: tensorflow/python/keras/layers/core.py

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import functools
import operator
@@ -1422,3 +1423,96 @@ class KerasOpDispatcher(dispatch.GlobalOpDispatcher):
return self.NOT_SUPPORTED
KerasOpDispatcher().register()
SliceTuple = collections.namedtuple('SliceTuple', ['start', 'stop', 'step'])
def _slice_to_named_tuple(x):
if isinstance(x, slice):
return SliceTuple(x.start, x.stop, x.step)
return x
def _named_tuple_to_slice(x):
if type(x).__name__ == 'SliceTuple':
return slice(x[0], x[1], x[2])
return x
class SlicingOpLambda(TFOpLambda):
"""Wraps TF API symbols in a `Layer` object.
It is inserted by the Functional API construction whenever users call
a supported TF symbol on KerasTensors.
Like Lambda layers, this layer tries to raise warnings when it detects users
explicitly use variables in the call. (To let them know
that the layer will not capture the variables).
This is useful in the case where users do something like:
x = keras.Input(...)
y = tf.Variable(...)
out = x * tf_variable
"""
@trackable.no_automatic_dependency_tracking
def __init__(self, function, **kwargs):
super(SlicingOpLambda, self).__init__(function, **kwargs)
original_call = self.call
# Decorate the function to produce this layer's call method
def _call_wrapper(*args, **kwargs):
# Turn any slice namedtuples in the args back into `slice` objects.
# This conversion cannot use nest.flatten/map_structure,
# because namedtuples are flattened by nest while slices aren't.
# So, map_structure would only see the individual elements in the
# namedtuple.
# This can't use map_structure_up_to either because the 'shallowness' of
# the shallow tree would have to vary depending on if only one dim or
# multiple are being sliced.
new_args = []
for arg in args:
arg = _named_tuple_to_slice(arg)
if isinstance(arg, (list, tuple)):
new_arg = []
for sub_arg in arg:
new_arg.append(_named_tuple_to_slice(sub_arg))
arg = new_arg
new_args.append(arg)
# Handle the kwargs too.
new_kwargs = {}
for key, value in kwargs.items():
value = _named_tuple_to_slice(value)
if isinstance(value, (list, tuple)):
new_value = []
for v in value:
new_value.append(_named_tuple_to_slice(v))
value = new_value
new_kwargs[key] = value
return original_call(*new_args, **new_kwargs)
self.call = tf_decorator.make_decorator(original_call, _call_wrapper)
class TFSlicingOpDispatcher(dispatch.OpDispatcher):
"""A global dispatcher that allows building a functional model with TF Ops."""
def __init__(self, op):
self.op = op
def handle(self, args, kwargs):
"""Handle the specified operation with the specified arguments."""
args = nest.map_structure(_slice_to_named_tuple, args)
kwargs = nest.map_structure(_slice_to_named_tuple, kwargs)
if any(
isinstance(x, keras_tensor.KerasTensor)
for x in nest.flatten([args, kwargs])):
return SlicingOpLambda(self.op)(*args, **kwargs)
else:
return self.NOT_SUPPORTED
for slicing_op in [array_ops._slice_helper, # pylint: disable=protected-access
array_ops.boolean_mask,
array_ops.boolean_mask_v2]:
TFSlicingOpDispatcher(slicing_op).register(slicing_op)
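
For context, a hedged end-to-end sketch of what this dispatch path enables once KerasTensors are enabled (the layer-name assertion mirrors the tests in the next file):

from tensorflow.python import keras

inp = keras.Input(shape=(4, 3, 8))
stop = keras.Input(shape=(), dtype='int32')
# __getitem__ dispatches to TFSlicingOpDispatcher, which inserts a single
# SlicingOpLambda layer instead of tracing through tf.strided_slice.
out = inp[..., :stop[0]]
model = keras.Model(inputs=[inp, stop], outputs=out)
assert 'tf.__operators__.getitem' in [layer.name for layer in model.layers]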

File: tensorflow/python/keras/layers/tensorflow_op_layer_test.py

@@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.keras.saving import model_config
from tensorflow.python.ops import array_ops
@@ -294,6 +295,238 @@ class AutoLambdaTest(keras_parameterized.TestCase):
self.assertAllEqual([layer.name for layer in model.layers],
[layer.name for layer in new_model.layers])
def test_getitem_slice_with_step_only(self):
if not context.executing_eagerly():
self.skipTest('Complex slicing like this fails in v1')
inp = keras.Input(shape=(4, 3, 8))
slice_step = keras.Input(shape=(), dtype='int32')
out = inp[..., ::slice_step[0]]
model = keras.Model(
inputs=[inp, slice_step],
outputs=out)
model.compile(
adam.Adam(0.001),
'mse',
run_eagerly=testing_utils.should_run_eagerly())
batch_size = 7
step = 3
x = array_ops.stack([
math_ops.range(8) for _ in range(batch_size)])
args = [x, constant_op.constant(step, shape=(batch_size,))]
expected = array_ops.stack([
math_ops.range(8)[::step] for _ in range(batch_size)])
if keras_tensor.keras_tensors_enabled():
self.assertIn('tf.__operators__.getitem', (
x.name for x in model.layers))
self.assertNotIn('tf.strided_slice', (
x.name for x in model.layers))
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
# Make sure it can be successfully saved and loaded
config = model.get_config()
model = keras.Model.from_config(config)
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
def test_getitem_slice_real_tensor(self):
if not context.executing_eagerly():
self.skipTest('Complex slicing like this fails in v1')
x = math_ops.range(10.0)
slice_stop = keras.Input(shape=(), dtype='int32')
out = x[:slice_stop[0]]
model = keras.Model(
inputs=slice_stop,
outputs=out)
model.compile(
adam.Adam(0.001),
'mse',
run_eagerly=testing_utils.should_run_eagerly())
batch_size = 7
stop = 6
args = constant_op.constant(stop, shape=(batch_size,))
expected = x[:stop]
if keras_tensor.keras_tensors_enabled():
self.assertIn('tf.__operators__.getitem', (
x.name for x in model.layers))
# TODO(b/161925288): Fix the dispatch triggering then uncomment:
# self.assertNotIn('tf.strided_slice', (
# x.name for x in model.layers))
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
# TODO(b/161925288): Fix the bug then uncomment:
# # Make sure it can be successfully saved and loaded
# config = model.get_config()
# model = keras.Model.from_config(config)
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
def test_getitem_index_real_tensor(self):
if not context.executing_eagerly():
self.skipTest('Complex slicing like this fails in v1')
x = math_ops.range(10.0)
slice_stop = keras.Input(shape=(), dtype='int32')
out = x[slice_stop[0]]
model = keras.Model(
inputs=slice_stop,
outputs=out)
model.compile(
adam.Adam(0.001),
'mse',
run_eagerly=testing_utils.should_run_eagerly())
batch_size = 7
index = 6
args = constant_op.constant(index, shape=(batch_size,))
expected = x[index]
if keras_tensor.keras_tensors_enabled():
self.assertIn('tf.__operators__.getitem', (
x.name for x in model.layers))
# TODO(b/161925288): Fix the bug then uncomment:
# self.assertNotIn('tf.strided_slice', (
# x.name for x in model.layers))
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
# TODO(b/161925288): Fix the bug then uncomment:
# # Make sure it can be successfully saved and loaded
# config = model.get_config()
# model = keras.Model.from_config(config)
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
def test_getitem_slice_with_stop_only(self):
if not context.executing_eagerly():
self.skipTest('Complex slicing like this fails in v1')
inp = keras.Input(shape=(4, 3, 8))
slice_stop = keras.Input(shape=(), dtype='int32')
out = inp[:slice_stop[0]]
model = keras.Model(
inputs=[inp, slice_stop],
outputs=out)
model.compile(
adam.Adam(0.001),
'mse',
run_eagerly=testing_utils.should_run_eagerly())
batch_size = 7
stop = 6
x = array_ops.stack([
math_ops.range(8) for _ in range(batch_size)])
args = [x, constant_op.constant(stop, shape=(batch_size,))]
expected = x[:stop]
if keras_tensor.keras_tensors_enabled():
self.assertIn('tf.__operators__.getitem', (
x.name for x in model.layers))
self.assertNotIn('tf.strided_slice', (
x.name for x in model.layers))
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
# Make sure it can be successfully saved and loaded
config = model.get_config()
model = keras.Model.from_config(config)
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
def test_getitem_slice_with_stop_and_ellipsis_only(self):
if not context.executing_eagerly():
self.skipTest('Complex slicing like this fails in v1')
inp = keras.Input(shape=(4, 3, 8))
slice_stop = keras.Input(shape=(), dtype='int32')
out = inp[..., :slice_stop[0]]
model = keras.Model(
inputs=[inp, slice_stop],
outputs=out)
model.compile(
adam.Adam(0.001),
'mse',
run_eagerly=testing_utils.should_run_eagerly())
batch_size = 7
stop = 6
x = array_ops.stack([
math_ops.range(8) for _ in range(batch_size)])
args = [x, constant_op.constant(stop, shape=(batch_size,))]
expected = array_ops.stack([
math_ops.range(8)[:stop] for _ in range(batch_size)])
if keras_tensor.keras_tensors_enabled():
self.assertIn('tf.__operators__.getitem', (
x.name for x in model.layers))
self.assertNotIn('tf.strided_slice', (
x.name for x in model.layers))
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
# Make sure it can be successfully saved and loaded
config = model.get_config()
model = keras.Model.from_config(config)
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
def test_getitem_complex_slicing(self):
if not context.executing_eagerly():
self.skipTest('Complex slicing like this fails in v1')
inp = keras.Input(shape=(4, 3, 8))
first_dim = keras.Input(shape=(), dtype='int32')
slice_start = keras.Input(shape=(), dtype='int32')
slice_stop = keras.Input(shape=(), dtype='int32')
slice_stride = keras.Input(shape=(), dtype='int32')
out = inp[..., first_dim[0], slice_start[0]:slice_stop[0]:slice_stride[0]]
model = keras.Model(
inputs=[inp, first_dim, slice_start, slice_stop, slice_stride],
outputs=out)
model.compile(
adam.Adam(0.001),
'mse',
run_eagerly=testing_utils.should_run_eagerly())
batch_size = 7
start = 1
stop = 6
step = 2
x = array_ops.stack([array_ops.stack([array_ops.stack([
math_ops.range(8)
for _ in range(3)]) for _ in range(4)]) for _ in range(batch_size)])
args = [x,
constant_op.constant(0, shape=(batch_size,)),
constant_op.constant(start, shape=(batch_size,)),
constant_op.constant(stop, shape=(batch_size,)),
constant_op.constant(step, shape=(batch_size,))]
# Slice the innermost dim; only grab one index from the second-to-innermost
# dim, removing that dim from the shape.
expected = array_ops.stack([array_ops.stack([
math_ops.range(8)[start:stop:step]
for _ in range(4)]) for _ in range(batch_size)])
if keras_tensor.keras_tensors_enabled():
self.assertIn('tf.__operators__.getitem', (
x.name for x in model.layers))
self.assertNotIn('tf.strided_slice', (
x.name for x in model.layers))
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
# Make sure it can be successfully saved and loaded
config = model.get_config()
model = keras.Model.from_config(config)
self.assertAllEqual(model(args), expected)
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
def test_numerical_correctness_simple(self):
x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]])
inputs = keras.Input(shape=(4,))

File: tensorflow/python/keras/saving/saved_model/json_utils.py

@@ -70,11 +70,14 @@ def decode(json_string):
def _decode_helper(obj):
"""A decoding helper that is TF-object aware."""
if isinstance(obj, dict) and 'class_name' in obj:
if obj['class_name'] == 'TensorShape':
return tensor_shape.TensorShape(obj['items'])
elif obj['class_name'] == '__tuple__':
return tuple(_decode_helper(i) for i in obj['items'])
elif obj['class_name'] == '__ellipsis__':
return Ellipsis
return obj
@@ -122,6 +125,9 @@ def get_json_type(obj):
if isinstance(obj, collections_abc.Mapping):
return dict(obj)
if obj is Ellipsis:
return {'class_name': '__ellipsis__'}
if isinstance(obj, wrapt.ObjectProxy):
return obj.__wrapped__
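
A minimal round-trip sketch using the two helpers in this file (assuming, as elsewhere in Keras saving, that `get_json_type` is passed as the `default` hook to `json.dumps`):

import json

blob = json.dumps(Ellipsis, default=get_json_type)
print(blob)  # {"class_name": "__ellipsis__"}
assert decode(blob) is Ellipsis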

File: tensorflow/python/ops/array_ops.py

@@ -955,6 +955,8 @@ def _slice_helper(tensor, slice_spec, var=None):
TypeError: If the slice indices aren't int, slice, ellipsis,
tf.newaxis or scalar int32/int64 tensors.
"""
tensor = ops.convert_to_tensor(tensor)
if isinstance(slice_spec, bool) or \
(isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
(isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):

File: tensorflow/python/util/dispatch_test.py

@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.proto_ops import decode_proto
@@ -28,6 +29,7 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
from tensorflow.python.util.tf_export import tf_export
@@ -68,10 +70,38 @@ class TensorTracer(object):
["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
return "{}({})".format(self.name, ", ".join(args))
@classmethod
def _overload_all_operators(cls): # pylint: disable=invalid-name
"""Register overloads for all operators."""
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
cls._overload_operator(operator)
@classmethod
def _overload_operator(cls, operator): # pylint: disable=invalid-name
"""Overload an operator with the same overloading as `ops.Tensor`."""
tensor_oper = getattr(ops.Tensor, operator)
# Compatibility with Python 2:
# Python 2 unbound methods have type checks for the first arg,
# so we need to extract the underlying function
tensor_oper = getattr(tensor_oper, "__func__", tensor_oper)
setattr(cls, operator, tensor_oper)
TensorTracer._overload_all_operators() # pylint: disable=protected-access
class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
"""Global op dispatcher for TensorTracer."""
def _flatten_with_slice_flattening(self, x):
flat = []
for val in nest.flatten(x):
if isinstance(val, slice):
flat.extend((val.start, val.stop, val.step))
else:
flat.append(val)
return flat
def handle(self, op, args, kwargs):
# Dispatcher only applies if at least one arg is a TensorTracer.
if not (any(self.is_tensor_tracer_arg(x) for x in args) or
@@ -82,11 +112,8 @@ class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
return TensorTracer(symbol_name, args, kwargs)
def is_tensor_tracer_arg(self, value):
- if isinstance(value, TensorTracer):
-   return True
- if isinstance(value, (list, tuple)):
-   if any(isinstance(x, TensorTracer) for x in value):
-     return True
+ return any(isinstance(x, TensorTracer) for x in
+            self._flatten_with_slice_flattening(value))
@test_util.run_all_in_graph_and_eager_modes
@@ -214,5 +241,46 @@ class DispatchTest(test_util.TensorFlowTestCase):
# Clean up.
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
def testGlobalDispatcherGetItem(self):
original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
try:
TensorTracerOpDispatcher().register()
x = TensorTracer("x")
trace = x[0]
self.assertEqual(
str(trace),
"__operators__.getitem(x, 0)")
x = TensorTracer("x")
y = TensorTracer("y")
trace = x[y]
self.assertEqual(
str(trace),
"__operators__.getitem(x, y)")
x = TensorTracer("x")
y = TensorTracer("y")
trace = x[:y] # pylint: disable=invalid-slice-index
self.assertEqual(
str(trace),
"__operators__.getitem(x, slice(None, y, None))")
x = array_ops.ones(shape=(3, 3))
y = TensorTracer("y")
trace = x[y]
self.assertEqual(
str(trace),
"__operators__.getitem(%s, y)" % x)
trace = x[:y] # pylint: disable=invalid-slice-index
self.assertEqual(
str(trace),
"__operators__.getitem(%s, slice(None, y, None))" % x)
finally:
# Clean up.
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
if __name__ == "__main__":
googletest.main()

File: tensorflow/python/util/serialization.py

@@ -70,6 +70,9 @@ def get_json_type(obj):
if isinstance(obj, collections_abc.Mapping):
return dict(obj)
if obj is Ellipsis:
return {'class_name': '__ellipsis__'}
if isinstance(obj, wrapt.ObjectProxy):
return obj.__wrapped__