Add a convert_to_tensor
to the start of Tensor.__getitem__ (_slice_helper) to make sure it dispatches directly, rather than letting the nested tf.strided_slice trigger dispatching.
This is important because `tensor.__getitem__` does some input arg manipulation before getting to the `tf.strided_slice`. So, when we try to run the traced code using the args provided to `strided_slice` (e.g. for KerasTensors), we lose information about constants that TPUs need to compile graphs involving shape manipulation. Tracing `__getitem__` and its input args directly does not seem to run into this problem. (Note: this TPU situation is separate from the shape value inferring we do in KerasTensors during Functional API construction/tracing time. This happens at model run-time when running the already-traced code) To get this all to work correctly in practice when dispatching KerasTensors + serializing/deserializing Keras models, this CL also has to: * Add special KerasTensor dispatchers for APIs that may take `slices` as inputs, to make sure they can trigger dispatching & serialize/deserialize correctly. This specialized dispatcher makes sure to unpack any `slices` in the args/kwargs into a namedtuple, before passing it to a specialized Keras TFOpLambda subclass that re-packs any slices. * Add serialization/deserialization support for `ellipsis` objects in Keras ------------------------ Other considered alternatives to get the dispatching/serialization to work correctly for KerasTensors: * add flatten/pack support for slices to `tf.nest`/`tree`. This can be revisited in the future (especially re: dispatchv2), but tree is critical path code and it's not obvious if we should always be flattening/packing slices or not. * Make the dispatched __operators__.getitem method expect slices to have already been unwrapped, and add a step to the __getitem__ overriding that unwraps the slices. This would be somewhat clunky in practice because there are other TF apis that take `slice`s in their args as well, and it might be surprising to dispatch users that the __operators__.getitem dispatch doesn't actually match the standard __getitem__ api. Likewise it's unclear what the performance implication of doing extra packing/unpacking even when not dispatching would be. PiperOrigin-RevId: 322655930 Change-Id: I35417577199393c016f753be685bf2926d62e753
This commit is contained in:
parent
e2f7c83bcb
commit
e3be70aa9d
tensorflow/python
keras
ops
util
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import operator
|
import operator
|
||||||
@ -1422,3 +1423,96 @@ class KerasOpDispatcher(dispatch.GlobalOpDispatcher):
|
|||||||
return self.NOT_SUPPORTED
|
return self.NOT_SUPPORTED
|
||||||
|
|
||||||
KerasOpDispatcher().register()
|
KerasOpDispatcher().register()
|
||||||
|
|
||||||
|
SliceTuple = collections.namedtuple('SliceTuple', ['start', 'stop', 'step'])
|
||||||
|
|
||||||
|
|
||||||
|
def _slice_to_named_tuple(x):
|
||||||
|
if isinstance(x, slice):
|
||||||
|
return SliceTuple(x.start, x.stop, x.step)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _named_tuple_to_slice(x):
|
||||||
|
if type(x).__name__ == 'SliceTuple':
|
||||||
|
return slice(x[0], x[1], x[2])
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SlicingOpLambda(TFOpLambda):
|
||||||
|
"""Wraps TF API symbols in a `Layer` object.
|
||||||
|
|
||||||
|
It is inserted by the Functional API construction whenever users call
|
||||||
|
a supported TF symbol on KerasTensors.
|
||||||
|
|
||||||
|
Like Lambda layers, this layer tries to raise warnings when it detects users
|
||||||
|
explicitly use variables in the call. (To let them know
|
||||||
|
that the layer will not capture the variables).
|
||||||
|
|
||||||
|
This is useful in the case where users do something like:
|
||||||
|
x = keras.Input(...)
|
||||||
|
y = tf.Variable(...)
|
||||||
|
out = x * tf_variable
|
||||||
|
"""
|
||||||
|
|
||||||
|
@trackable.no_automatic_dependency_tracking
|
||||||
|
def __init__(self, function, **kwargs):
|
||||||
|
super(SlicingOpLambda, self).__init__(function, **kwargs)
|
||||||
|
|
||||||
|
original_call = self.call
|
||||||
|
# Decorate the function to produce this layer's call method
|
||||||
|
def _call_wrapper(*args, **kwargs):
|
||||||
|
# Turn any slice nametuples in the args back into `slice` objects.
|
||||||
|
# This conversion cannot use nest.flatten/map_structure,
|
||||||
|
# because namedtuples are flattened by nest while slices aren't.
|
||||||
|
# So, map_structure would only see the individual elements in the
|
||||||
|
# namedtuple.
|
||||||
|
# This can't use map_structure_up_to either because the 'shallowness' of
|
||||||
|
# the shallow tree would have to vary depending on if only one dim or
|
||||||
|
# multiple are being sliced.
|
||||||
|
new_args = []
|
||||||
|
for arg in args:
|
||||||
|
arg = _named_tuple_to_slice(arg)
|
||||||
|
if isinstance(arg, (list, tuple)):
|
||||||
|
new_arg = []
|
||||||
|
for sub_arg in arg:
|
||||||
|
new_arg.append(_named_tuple_to_slice(sub_arg))
|
||||||
|
arg = new_arg
|
||||||
|
new_args.append(arg)
|
||||||
|
|
||||||
|
# Handle the kwargs too.
|
||||||
|
new_kwargs = {}
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
value = _named_tuple_to_slice(value)
|
||||||
|
if isinstance(value, (list, tuple)):
|
||||||
|
new_value = []
|
||||||
|
for v in value:
|
||||||
|
new_value.append(_named_tuple_to_slice(v))
|
||||||
|
value = new_value
|
||||||
|
new_kwargs[key] = value
|
||||||
|
|
||||||
|
return original_call(*new_args, **new_kwargs)
|
||||||
|
self.call = tf_decorator.make_decorator(original_call, _call_wrapper)
|
||||||
|
|
||||||
|
|
||||||
|
class TFSlicingOpDispatcher(dispatch.OpDispatcher):
|
||||||
|
"""A global dispatcher that allows building a functional model with TF Ops."""
|
||||||
|
|
||||||
|
def __init__(self, op):
|
||||||
|
self.op = op
|
||||||
|
|
||||||
|
def handle(self, args, kwargs):
|
||||||
|
"""Handle the specified operation with the specified arguments."""
|
||||||
|
args = nest.map_structure(_slice_to_named_tuple, args)
|
||||||
|
kwargs = nest.map_structure(_slice_to_named_tuple, kwargs)
|
||||||
|
if any(
|
||||||
|
isinstance(x, keras_tensor.KerasTensor)
|
||||||
|
for x in nest.flatten([args, kwargs])):
|
||||||
|
return SlicingOpLambda(self.op)(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return self.NOT_SUPPORTED
|
||||||
|
|
||||||
|
for slicing_op in [array_ops._slice_helper, # pylint: disable=protected-access
|
||||||
|
array_ops.boolean_mask,
|
||||||
|
array_ops.boolean_mask_v2]:
|
||||||
|
TFSlicingOpDispatcher(slicing_op).register(slicing_op)
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
|
from tensorflow.python.keras.engine import keras_tensor
|
||||||
from tensorflow.python.keras.optimizer_v2 import adam
|
from tensorflow.python.keras.optimizer_v2 import adam
|
||||||
from tensorflow.python.keras.saving import model_config
|
from tensorflow.python.keras.saving import model_config
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -294,6 +295,238 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
|||||||
self.assertAllEqual([layer.name for layer in model.layers],
|
self.assertAllEqual([layer.name for layer in model.layers],
|
||||||
[layer.name for layer in new_model.layers])
|
[layer.name for layer in new_model.layers])
|
||||||
|
|
||||||
|
def test_getitem_slice_with_step_only(self):
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.skipTest('Complex slicing like this fails in v1')
|
||||||
|
inp = keras.Input(shape=(4, 3, 8))
|
||||||
|
slice_step = keras.Input(shape=(), dtype='int32')
|
||||||
|
|
||||||
|
out = inp[..., ::slice_step[0]]
|
||||||
|
model = keras.Model(
|
||||||
|
inputs=[inp, slice_step],
|
||||||
|
outputs=out)
|
||||||
|
model.compile(
|
||||||
|
adam.Adam(0.001),
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
batch_size = 7
|
||||||
|
step = 3
|
||||||
|
x = array_ops.stack([
|
||||||
|
math_ops.range(8) for _ in range(batch_size)])
|
||||||
|
args = [x, constant_op.constant(step, shape=(batch_size,))]
|
||||||
|
expected = array_ops.stack([
|
||||||
|
math_ops.range(8)[::step] for _ in range(batch_size)])
|
||||||
|
|
||||||
|
if keras_tensor.keras_tensors_enabled():
|
||||||
|
self.assertIn('tf.__operators__.getitem', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
self.assertNotIn('tf.strided_slice', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
# Make sure it can be successfully saved and loaded
|
||||||
|
config = model.get_config()
|
||||||
|
model = keras.Model.from_config(config)
|
||||||
|
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
def test_getitem_slice_real_tensor(self):
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.skipTest('Complex slicing like this fails in v1')
|
||||||
|
x = math_ops.range(10.0)
|
||||||
|
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||||
|
|
||||||
|
out = x[:slice_stop[0]]
|
||||||
|
model = keras.Model(
|
||||||
|
inputs=slice_stop,
|
||||||
|
outputs=out)
|
||||||
|
model.compile(
|
||||||
|
adam.Adam(0.001),
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
batch_size = 7
|
||||||
|
stop = 6
|
||||||
|
args = constant_op.constant(stop, shape=(batch_size,))
|
||||||
|
expected = x[:stop]
|
||||||
|
|
||||||
|
if keras_tensor.keras_tensors_enabled():
|
||||||
|
self.assertIn('tf.__operators__.getitem', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
# TODO(b/161925288): Fix the dispatch triggering then uncomment:
|
||||||
|
# self.assertNotIn('tf.strided_slice', (
|
||||||
|
# x.name for x in model.layers))
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
# TODO(b/161925288): Fix the bug then uncomment:
|
||||||
|
# # Make sure it can be successfully saved and loaded
|
||||||
|
# config = model.get_config()
|
||||||
|
# model = keras.Model.from_config(config)
|
||||||
|
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
def test_getitem_index_real_tensor(self):
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.skipTest('Complex slicing like this fails in v1')
|
||||||
|
x = math_ops.range(10.0)
|
||||||
|
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||||
|
|
||||||
|
out = x[slice_stop[0]]
|
||||||
|
model = keras.Model(
|
||||||
|
inputs=slice_stop,
|
||||||
|
outputs=out)
|
||||||
|
model.compile(
|
||||||
|
adam.Adam(0.001),
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
batch_size = 7
|
||||||
|
index = 6
|
||||||
|
args = constant_op.constant(index, shape=(batch_size,))
|
||||||
|
expected = x[index]
|
||||||
|
|
||||||
|
if keras_tensor.keras_tensors_enabled():
|
||||||
|
self.assertIn('tf.__operators__.getitem', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
# TODO(b/161925288): Fix the bug then uncomment:
|
||||||
|
# self.assertNotIn('tf.strided_slice', (
|
||||||
|
# x.name for x in model.layers))
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
# TODO(b/161925288): Fix the bug then uncomment:
|
||||||
|
# # Make sure it can be successfully saved and loaded
|
||||||
|
# config = model.get_config()
|
||||||
|
# model = keras.Model.from_config(config)
|
||||||
|
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
def test_getitem_slice_with_stop_only(self):
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.skipTest('Complex slicing like this fails in v1')
|
||||||
|
inp = keras.Input(shape=(4, 3, 8))
|
||||||
|
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||||
|
|
||||||
|
out = inp[:slice_stop[0]]
|
||||||
|
model = keras.Model(
|
||||||
|
inputs=[inp, slice_stop],
|
||||||
|
outputs=out)
|
||||||
|
model.compile(
|
||||||
|
adam.Adam(0.001),
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
batch_size = 7
|
||||||
|
stop = 6
|
||||||
|
x = array_ops.stack([
|
||||||
|
math_ops.range(8) for _ in range(batch_size)])
|
||||||
|
args = [x, constant_op.constant(stop, shape=(batch_size,))]
|
||||||
|
expected = x[:stop]
|
||||||
|
|
||||||
|
if keras_tensor.keras_tensors_enabled():
|
||||||
|
self.assertIn('tf.__operators__.getitem', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
self.assertNotIn('tf.strided_slice', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
# Make sure it can be successfully saved and loaded
|
||||||
|
config = model.get_config()
|
||||||
|
model = keras.Model.from_config(config)
|
||||||
|
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
def test_getitem_slice_with_stop_and_ellipsis_only(self):
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.skipTest('Complex slicing like this fails in v1')
|
||||||
|
inp = keras.Input(shape=(4, 3, 8))
|
||||||
|
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||||
|
|
||||||
|
out = inp[..., :slice_stop[0]]
|
||||||
|
model = keras.Model(
|
||||||
|
inputs=[inp, slice_stop],
|
||||||
|
outputs=out)
|
||||||
|
model.compile(
|
||||||
|
adam.Adam(0.001),
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
batch_size = 7
|
||||||
|
stop = 6
|
||||||
|
x = array_ops.stack([
|
||||||
|
math_ops.range(8) for _ in range(batch_size)])
|
||||||
|
args = [x, constant_op.constant(stop, shape=(batch_size,))]
|
||||||
|
expected = array_ops.stack([
|
||||||
|
math_ops.range(8)[:stop] for _ in range(batch_size)])
|
||||||
|
|
||||||
|
if keras_tensor.keras_tensors_enabled():
|
||||||
|
self.assertIn('tf.__operators__.getitem', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
self.assertNotIn('tf.strided_slice', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
# Make sure it can be successfully saved and loaded
|
||||||
|
config = model.get_config()
|
||||||
|
model = keras.Model.from_config(config)
|
||||||
|
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
def test_getitem_complex_slicing(self):
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.skipTest('Complex slicing like this fails in v1')
|
||||||
|
inp = keras.Input(shape=(4, 3, 8))
|
||||||
|
first_dim = keras.Input(shape=(), dtype='int32')
|
||||||
|
slice_start = keras.Input(shape=(), dtype='int32')
|
||||||
|
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||||
|
slice_stride = keras.Input(shape=(), dtype='int32')
|
||||||
|
|
||||||
|
out = inp[..., first_dim[0], slice_start[0]:slice_stop[0]:slice_stride[0]]
|
||||||
|
model = keras.Model(
|
||||||
|
inputs=[inp, first_dim, slice_start, slice_stop, slice_stride],
|
||||||
|
outputs=out)
|
||||||
|
model.compile(
|
||||||
|
adam.Adam(0.001),
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
batch_size = 7
|
||||||
|
start = 1
|
||||||
|
stop = 6
|
||||||
|
step = 2
|
||||||
|
x = array_ops.stack([array_ops.stack([array_ops.stack([
|
||||||
|
math_ops.range(8)
|
||||||
|
for _ in range(3)]) for _ in range(4)]) for _ in range(batch_size)])
|
||||||
|
args = [x,
|
||||||
|
constant_op.constant(0, shape=(batch_size,)),
|
||||||
|
constant_op.constant(start, shape=(batch_size,)),
|
||||||
|
constant_op.constant(stop, shape=(batch_size,)),
|
||||||
|
constant_op.constant(step, shape=(batch_size,))]
|
||||||
|
# Slice the innermost dim. only grab one index from the second-to-innermost
|
||||||
|
# dim, removing that dim from the shape.
|
||||||
|
expected = array_ops.stack([array_ops.stack([
|
||||||
|
math_ops.range(8)[start:stop:step]
|
||||||
|
for _ in range(4)]) for _ in range(batch_size)])
|
||||||
|
|
||||||
|
if keras_tensor.keras_tensors_enabled():
|
||||||
|
self.assertIn('tf.__operators__.getitem', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
self.assertNotIn('tf.strided_slice', (
|
||||||
|
x.name for x in model.layers))
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
|
# Make sure it can be successfully saved and loaded
|
||||||
|
config = model.get_config()
|
||||||
|
model = keras.Model.from_config(config)
|
||||||
|
|
||||||
|
self.assertAllEqual(model(args), expected)
|
||||||
|
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||||
|
|
||||||
def test_numerical_correctness_simple(self):
|
def test_numerical_correctness_simple(self):
|
||||||
x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]])
|
x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]])
|
||||||
inputs = keras.Input(shape=(4,))
|
inputs = keras.Input(shape=(4,))
|
||||||
|
@ -70,11 +70,14 @@ def decode(json_string):
|
|||||||
|
|
||||||
|
|
||||||
def _decode_helper(obj):
|
def _decode_helper(obj):
|
||||||
|
"""A decoding helper that is TF-object aware."""
|
||||||
if isinstance(obj, dict) and 'class_name' in obj:
|
if isinstance(obj, dict) and 'class_name' in obj:
|
||||||
if obj['class_name'] == 'TensorShape':
|
if obj['class_name'] == 'TensorShape':
|
||||||
return tensor_shape.TensorShape(obj['items'])
|
return tensor_shape.TensorShape(obj['items'])
|
||||||
elif obj['class_name'] == '__tuple__':
|
elif obj['class_name'] == '__tuple__':
|
||||||
return tuple(_decode_helper(i) for i in obj['items'])
|
return tuple(_decode_helper(i) for i in obj['items'])
|
||||||
|
elif obj['class_name'] == '__ellipsis__':
|
||||||
|
return Ellipsis
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
@ -122,6 +125,9 @@ def get_json_type(obj):
|
|||||||
if isinstance(obj, collections_abc.Mapping):
|
if isinstance(obj, collections_abc.Mapping):
|
||||||
return dict(obj)
|
return dict(obj)
|
||||||
|
|
||||||
|
if obj is Ellipsis:
|
||||||
|
return {'class_name': '__ellipsis__'}
|
||||||
|
|
||||||
if isinstance(obj, wrapt.ObjectProxy):
|
if isinstance(obj, wrapt.ObjectProxy):
|
||||||
return obj.__wrapped__
|
return obj.__wrapped__
|
||||||
|
|
||||||
|
@ -955,6 +955,8 @@ def _slice_helper(tensor, slice_spec, var=None):
|
|||||||
TypeError: If the slice indices aren't int, slice, ellipsis,
|
TypeError: If the slice indices aren't int, slice, ellipsis,
|
||||||
tf.newaxis or scalar int32/int64 tensors.
|
tf.newaxis or scalar int32/int64 tensors.
|
||||||
"""
|
"""
|
||||||
|
tensor = ops.convert_to_tensor(tensor)
|
||||||
|
|
||||||
if isinstance(slice_spec, bool) or \
|
if isinstance(slice_spec, bool) or \
|
||||||
(isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
|
(isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
|
||||||
(isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
|
(isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.proto_ops import decode_proto
|
from tensorflow.python.ops.proto_ops import decode_proto
|
||||||
@ -28,6 +29,7 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import dispatch
|
from tensorflow.python.util import dispatch
|
||||||
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
|
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -68,10 +70,38 @@ class TensorTracer(object):
|
|||||||
["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
|
["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
|
||||||
return "{}({})".format(self.name, ", ".join(args))
|
return "{}({})".format(self.name, ", ".join(args))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _overload_all_operators(cls): # pylint: disable=invalid-name
|
||||||
|
"""Register overloads for all operators."""
|
||||||
|
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
|
||||||
|
cls._overload_operator(operator)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _overload_operator(cls, operator): # pylint: disable=invalid-name
|
||||||
|
"""Overload an operator with the same overloading as `ops.Tensor`."""
|
||||||
|
tensor_oper = getattr(ops.Tensor, operator)
|
||||||
|
|
||||||
|
# Compatibility with Python 2:
|
||||||
|
# Python 2 unbound methods have type checks for the first arg,
|
||||||
|
# so we need to extract the underlying function
|
||||||
|
tensor_oper = getattr(tensor_oper, "__func__", tensor_oper)
|
||||||
|
setattr(cls, operator, tensor_oper)
|
||||||
|
|
||||||
|
TensorTracer._overload_all_operators() # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
|
class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
|
||||||
"""Global op dispatcher for TensorTracer."""
|
"""Global op dispatcher for TensorTracer."""
|
||||||
|
|
||||||
|
def _flatten_with_slice_flattening(self, x):
|
||||||
|
flat = []
|
||||||
|
for val in nest.flatten(x):
|
||||||
|
if isinstance(val, slice):
|
||||||
|
flat.extend((val.start, val.stop, val.step))
|
||||||
|
else:
|
||||||
|
flat.append(val)
|
||||||
|
return flat
|
||||||
|
|
||||||
def handle(self, op, args, kwargs):
|
def handle(self, op, args, kwargs):
|
||||||
# Dispatcher only applies if at least one arg is a TensorTracer.
|
# Dispatcher only applies if at least one arg is a TensorTracer.
|
||||||
if not (any(self.is_tensor_tracer_arg(x) for x in args) or
|
if not (any(self.is_tensor_tracer_arg(x) for x in args) or
|
||||||
@ -82,11 +112,8 @@ class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
|
|||||||
return TensorTracer(symbol_name, args, kwargs)
|
return TensorTracer(symbol_name, args, kwargs)
|
||||||
|
|
||||||
def is_tensor_tracer_arg(self, value):
|
def is_tensor_tracer_arg(self, value):
|
||||||
if isinstance(value, TensorTracer):
|
return any(isinstance(x, TensorTracer) for x in
|
||||||
return True
|
self._flatten_with_slice_flattening(value))
|
||||||
if isinstance(value, (list, tuple)):
|
|
||||||
if any(isinstance(x, TensorTracer) for x in value):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
@ -214,5 +241,46 @@ class DispatchTest(test_util.TensorFlowTestCase):
|
|||||||
# Clean up.
|
# Clean up.
|
||||||
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
||||||
|
|
||||||
|
def testGlobalDispatcherGetItem(self):
|
||||||
|
original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
|
||||||
|
try:
|
||||||
|
TensorTracerOpDispatcher().register()
|
||||||
|
|
||||||
|
x = TensorTracer("x")
|
||||||
|
trace = x[0]
|
||||||
|
self.assertEqual(
|
||||||
|
str(trace),
|
||||||
|
"__operators__.getitem(x, 0)")
|
||||||
|
|
||||||
|
x = TensorTracer("x")
|
||||||
|
y = TensorTracer("y")
|
||||||
|
trace = x[y]
|
||||||
|
self.assertEqual(
|
||||||
|
str(trace),
|
||||||
|
"__operators__.getitem(x, y)")
|
||||||
|
|
||||||
|
x = TensorTracer("x")
|
||||||
|
y = TensorTracer("y")
|
||||||
|
trace = x[:y] # pylint: disable=invalid-slice-index
|
||||||
|
self.assertEqual(
|
||||||
|
str(trace),
|
||||||
|
"__operators__.getitem(x, slice(None, y, None))")
|
||||||
|
|
||||||
|
x = array_ops.ones(shape=(3, 3))
|
||||||
|
y = TensorTracer("y")
|
||||||
|
trace = x[y]
|
||||||
|
self.assertEqual(
|
||||||
|
str(trace),
|
||||||
|
"__operators__.getitem(%s, y)" % x)
|
||||||
|
|
||||||
|
trace = x[:y] # pylint: disable=invalid-slice-index
|
||||||
|
self.assertEqual(
|
||||||
|
str(trace),
|
||||||
|
"__operators__.getitem(%s, slice(None, y, None))" % x)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up.
|
||||||
|
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
@ -70,6 +70,9 @@ def get_json_type(obj):
|
|||||||
if isinstance(obj, collections_abc.Mapping):
|
if isinstance(obj, collections_abc.Mapping):
|
||||||
return dict(obj)
|
return dict(obj)
|
||||||
|
|
||||||
|
if obj is Ellipsis:
|
||||||
|
return {'class_name': '__ellipsis__'}
|
||||||
|
|
||||||
if isinstance(obj, wrapt.ObjectProxy):
|
if isinstance(obj, wrapt.ObjectProxy):
|
||||||
return obj.__wrapped__
|
return obj.__wrapped__
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user