Add a convert_to_tensor
to the start of Tensor.__getitem__ (_slice_helper) to make sure it dispatches directly, rather than letting the nested tf.strided_slice trigger dispatching.
This is important because `tensor.__getitem__` does some input arg manipulation before getting to the `tf.strided_slice`. So, when we try to run the traced code using the args provided to `strided_slice` (e.g. for KerasTensors), we lose information about constants that TPUs need to compile graphs involving shape manipulation. Tracing `__getitem__` and its input args directly does not seem to run into this problem. (Note: this TPU situation is separate from the shape value inferring we do in KerasTensors during Functional API construction/tracing time. This happens at model run-time when running the already-traced code) To get this all to work correctly in practice when dispatching KerasTensors + serializing/deserializing Keras models, this CL also has to: * Add special KerasTensor dispatchers for APIs that may take `slices` as inputs, to make sure they can trigger dispatching & serialize/deserialize correctly. This specialized dispatcher makes sure to unpack any `slices` in the args/kwargs into a namedtuple, before passing it to a specialized Keras TFOpLambda subclass that re-packs any slices. * Add serialization/deserialization support for `ellipsis` objects in Keras ------------------------ Other considered alternatives to get the dispatching/serialization to work correctly for KerasTensors: * add flatten/pack support for slices to `tf.nest`/`tree`. This can be revisited in the future (especially re: dispatchv2), but tree is critical path code and it's not obvious if we should always be flattening/packing slices or not. * Make the dispatched __operators__.getitem method expect slices to have already been unwrapped, and add a step to the __getitem__ overriding that unwraps the slices. This would be somewhat clunky in practice because there are other TF apis that take `slice`s in their args as well, and it might be surprising to dispatch users that the __operators__.getitem dispatch doesn't actually match the standard __getitem__ api. Likewise it's unclear what the performance implication of doing extra packing/unpacking even when not dispatching would be. PiperOrigin-RevId: 322655930 Change-Id: I35417577199393c016f753be685bf2926d62e753
This commit is contained in:
parent
e2f7c83bcb
commit
e3be70aa9d
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import functools
|
||||
import operator
|
||||
@ -1422,3 +1423,96 @@ class KerasOpDispatcher(dispatch.GlobalOpDispatcher):
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
KerasOpDispatcher().register()
|
||||
|
||||
SliceTuple = collections.namedtuple('SliceTuple', ['start', 'stop', 'step'])
|
||||
|
||||
|
||||
def _slice_to_named_tuple(x):
|
||||
if isinstance(x, slice):
|
||||
return SliceTuple(x.start, x.stop, x.step)
|
||||
return x
|
||||
|
||||
|
||||
def _named_tuple_to_slice(x):
|
||||
if type(x).__name__ == 'SliceTuple':
|
||||
return slice(x[0], x[1], x[2])
|
||||
return x
|
||||
|
||||
|
||||
class SlicingOpLambda(TFOpLambda):
|
||||
"""Wraps TF API symbols in a `Layer` object.
|
||||
|
||||
It is inserted by the Functional API construction whenever users call
|
||||
a supported TF symbol on KerasTensors.
|
||||
|
||||
Like Lambda layers, this layer tries to raise warnings when it detects users
|
||||
explicitly use variables in the call. (To let them know
|
||||
that the layer will not capture the variables).
|
||||
|
||||
This is useful in the case where users do something like:
|
||||
x = keras.Input(...)
|
||||
y = tf.Variable(...)
|
||||
out = x * tf_variable
|
||||
"""
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, function, **kwargs):
|
||||
super(SlicingOpLambda, self).__init__(function, **kwargs)
|
||||
|
||||
original_call = self.call
|
||||
# Decorate the function to produce this layer's call method
|
||||
def _call_wrapper(*args, **kwargs):
|
||||
# Turn any slice nametuples in the args back into `slice` objects.
|
||||
# This conversion cannot use nest.flatten/map_structure,
|
||||
# because namedtuples are flattened by nest while slices aren't.
|
||||
# So, map_structure would only see the individual elements in the
|
||||
# namedtuple.
|
||||
# This can't use map_structure_up_to either because the 'shallowness' of
|
||||
# the shallow tree would have to vary depending on if only one dim or
|
||||
# multiple are being sliced.
|
||||
new_args = []
|
||||
for arg in args:
|
||||
arg = _named_tuple_to_slice(arg)
|
||||
if isinstance(arg, (list, tuple)):
|
||||
new_arg = []
|
||||
for sub_arg in arg:
|
||||
new_arg.append(_named_tuple_to_slice(sub_arg))
|
||||
arg = new_arg
|
||||
new_args.append(arg)
|
||||
|
||||
# Handle the kwargs too.
|
||||
new_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
value = _named_tuple_to_slice(value)
|
||||
if isinstance(value, (list, tuple)):
|
||||
new_value = []
|
||||
for v in value:
|
||||
new_value.append(_named_tuple_to_slice(v))
|
||||
value = new_value
|
||||
new_kwargs[key] = value
|
||||
|
||||
return original_call(*new_args, **new_kwargs)
|
||||
self.call = tf_decorator.make_decorator(original_call, _call_wrapper)
|
||||
|
||||
|
||||
class TFSlicingOpDispatcher(dispatch.OpDispatcher):
|
||||
"""A global dispatcher that allows building a functional model with TF Ops."""
|
||||
|
||||
def __init__(self, op):
|
||||
self.op = op
|
||||
|
||||
def handle(self, args, kwargs):
|
||||
"""Handle the specified operation with the specified arguments."""
|
||||
args = nest.map_structure(_slice_to_named_tuple, args)
|
||||
kwargs = nest.map_structure(_slice_to_named_tuple, kwargs)
|
||||
if any(
|
||||
isinstance(x, keras_tensor.KerasTensor)
|
||||
for x in nest.flatten([args, kwargs])):
|
||||
return SlicingOpLambda(self.op)(*args, **kwargs)
|
||||
else:
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
for slicing_op in [array_ops._slice_helper, # pylint: disable=protected-access
|
||||
array_ops.boolean_mask,
|
||||
array_ops.boolean_mask_v2]:
|
||||
TFSlicingOpDispatcher(slicing_op).register(slicing_op)
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.optimizer_v2 import adam
|
||||
from tensorflow.python.keras.saving import model_config
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -294,6 +295,238 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
self.assertAllEqual([layer.name for layer in model.layers],
|
||||
[layer.name for layer in new_model.layers])
|
||||
|
||||
def test_getitem_slice_with_step_only(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('Complex slicing like this fails in v1')
|
||||
inp = keras.Input(shape=(4, 3, 8))
|
||||
slice_step = keras.Input(shape=(), dtype='int32')
|
||||
|
||||
out = inp[..., ::slice_step[0]]
|
||||
model = keras.Model(
|
||||
inputs=[inp, slice_step],
|
||||
outputs=out)
|
||||
model.compile(
|
||||
adam.Adam(0.001),
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
batch_size = 7
|
||||
step = 3
|
||||
x = array_ops.stack([
|
||||
math_ops.range(8) for _ in range(batch_size)])
|
||||
args = [x, constant_op.constant(step, shape=(batch_size,))]
|
||||
expected = array_ops.stack([
|
||||
math_ops.range(8)[::step] for _ in range(batch_size)])
|
||||
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
self.assertIn('tf.__operators__.getitem', (
|
||||
x.name for x in model.layers))
|
||||
self.assertNotIn('tf.strided_slice', (
|
||||
x.name for x in model.layers))
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
# Make sure it can be successfully saved and loaded
|
||||
config = model.get_config()
|
||||
model = keras.Model.from_config(config)
|
||||
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
def test_getitem_slice_real_tensor(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('Complex slicing like this fails in v1')
|
||||
x = math_ops.range(10.0)
|
||||
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||
|
||||
out = x[:slice_stop[0]]
|
||||
model = keras.Model(
|
||||
inputs=slice_stop,
|
||||
outputs=out)
|
||||
model.compile(
|
||||
adam.Adam(0.001),
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
batch_size = 7
|
||||
stop = 6
|
||||
args = constant_op.constant(stop, shape=(batch_size,))
|
||||
expected = x[:stop]
|
||||
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
self.assertIn('tf.__operators__.getitem', (
|
||||
x.name for x in model.layers))
|
||||
# TODO(b/161925288): Fix the dispatch triggering then uncomment:
|
||||
# self.assertNotIn('tf.strided_slice', (
|
||||
# x.name for x in model.layers))
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
# TODO(b/161925288): Fix the bug then uncomment:
|
||||
# # Make sure it can be successfully saved and loaded
|
||||
# config = model.get_config()
|
||||
# model = keras.Model.from_config(config)
|
||||
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
def test_getitem_index_real_tensor(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('Complex slicing like this fails in v1')
|
||||
x = math_ops.range(10.0)
|
||||
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||
|
||||
out = x[slice_stop[0]]
|
||||
model = keras.Model(
|
||||
inputs=slice_stop,
|
||||
outputs=out)
|
||||
model.compile(
|
||||
adam.Adam(0.001),
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
batch_size = 7
|
||||
index = 6
|
||||
args = constant_op.constant(index, shape=(batch_size,))
|
||||
expected = x[index]
|
||||
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
self.assertIn('tf.__operators__.getitem', (
|
||||
x.name for x in model.layers))
|
||||
# TODO(b/161925288): Fix the bug then uncomment:
|
||||
# self.assertNotIn('tf.strided_slice', (
|
||||
# x.name for x in model.layers))
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
# TODO(b/161925288): Fix the bug then uncomment:
|
||||
# # Make sure it can be successfully saved and loaded
|
||||
# config = model.get_config()
|
||||
# model = keras.Model.from_config(config)
|
||||
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
def test_getitem_slice_with_stop_only(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('Complex slicing like this fails in v1')
|
||||
inp = keras.Input(shape=(4, 3, 8))
|
||||
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||
|
||||
out = inp[:slice_stop[0]]
|
||||
model = keras.Model(
|
||||
inputs=[inp, slice_stop],
|
||||
outputs=out)
|
||||
model.compile(
|
||||
adam.Adam(0.001),
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
batch_size = 7
|
||||
stop = 6
|
||||
x = array_ops.stack([
|
||||
math_ops.range(8) for _ in range(batch_size)])
|
||||
args = [x, constant_op.constant(stop, shape=(batch_size,))]
|
||||
expected = x[:stop]
|
||||
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
self.assertIn('tf.__operators__.getitem', (
|
||||
x.name for x in model.layers))
|
||||
self.assertNotIn('tf.strided_slice', (
|
||||
x.name for x in model.layers))
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
# Make sure it can be successfully saved and loaded
|
||||
config = model.get_config()
|
||||
model = keras.Model.from_config(config)
|
||||
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
def test_getitem_slice_with_stop_and_ellipsis_only(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('Complex slicing like this fails in v1')
|
||||
inp = keras.Input(shape=(4, 3, 8))
|
||||
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||
|
||||
out = inp[..., :slice_stop[0]]
|
||||
model = keras.Model(
|
||||
inputs=[inp, slice_stop],
|
||||
outputs=out)
|
||||
model.compile(
|
||||
adam.Adam(0.001),
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
batch_size = 7
|
||||
stop = 6
|
||||
x = array_ops.stack([
|
||||
math_ops.range(8) for _ in range(batch_size)])
|
||||
args = [x, constant_op.constant(stop, shape=(batch_size,))]
|
||||
expected = array_ops.stack([
|
||||
math_ops.range(8)[:stop] for _ in range(batch_size)])
|
||||
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
self.assertIn('tf.__operators__.getitem', (
|
||||
x.name for x in model.layers))
|
||||
self.assertNotIn('tf.strided_slice', (
|
||||
x.name for x in model.layers))
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
# Make sure it can be successfully saved and loaded
|
||||
config = model.get_config()
|
||||
model = keras.Model.from_config(config)
|
||||
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
def test_getitem_complex_slicing(self):
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest('Complex slicing like this fails in v1')
|
||||
inp = keras.Input(shape=(4, 3, 8))
|
||||
first_dim = keras.Input(shape=(), dtype='int32')
|
||||
slice_start = keras.Input(shape=(), dtype='int32')
|
||||
slice_stop = keras.Input(shape=(), dtype='int32')
|
||||
slice_stride = keras.Input(shape=(), dtype='int32')
|
||||
|
||||
out = inp[..., first_dim[0], slice_start[0]:slice_stop[0]:slice_stride[0]]
|
||||
model = keras.Model(
|
||||
inputs=[inp, first_dim, slice_start, slice_stop, slice_stride],
|
||||
outputs=out)
|
||||
model.compile(
|
||||
adam.Adam(0.001),
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
batch_size = 7
|
||||
start = 1
|
||||
stop = 6
|
||||
step = 2
|
||||
x = array_ops.stack([array_ops.stack([array_ops.stack([
|
||||
math_ops.range(8)
|
||||
for _ in range(3)]) for _ in range(4)]) for _ in range(batch_size)])
|
||||
args = [x,
|
||||
constant_op.constant(0, shape=(batch_size,)),
|
||||
constant_op.constant(start, shape=(batch_size,)),
|
||||
constant_op.constant(stop, shape=(batch_size,)),
|
||||
constant_op.constant(step, shape=(batch_size,))]
|
||||
# Slice the innermost dim. only grab one index from the second-to-innermost
|
||||
# dim, removing that dim from the shape.
|
||||
expected = array_ops.stack([array_ops.stack([
|
||||
math_ops.range(8)[start:stop:step]
|
||||
for _ in range(4)]) for _ in range(batch_size)])
|
||||
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
self.assertIn('tf.__operators__.getitem', (
|
||||
x.name for x in model.layers))
|
||||
self.assertNotIn('tf.strided_slice', (
|
||||
x.name for x in model.layers))
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
# Make sure it can be successfully saved and loaded
|
||||
config = model.get_config()
|
||||
model = keras.Model.from_config(config)
|
||||
|
||||
self.assertAllEqual(model(args), expected)
|
||||
self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
|
||||
|
||||
def test_numerical_correctness_simple(self):
|
||||
x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]])
|
||||
inputs = keras.Input(shape=(4,))
|
||||
|
@ -70,11 +70,14 @@ def decode(json_string):
|
||||
|
||||
|
||||
def _decode_helper(obj):
|
||||
"""A decoding helper that is TF-object aware."""
|
||||
if isinstance(obj, dict) and 'class_name' in obj:
|
||||
if obj['class_name'] == 'TensorShape':
|
||||
return tensor_shape.TensorShape(obj['items'])
|
||||
elif obj['class_name'] == '__tuple__':
|
||||
return tuple(_decode_helper(i) for i in obj['items'])
|
||||
elif obj['class_name'] == '__ellipsis__':
|
||||
return Ellipsis
|
||||
return obj
|
||||
|
||||
|
||||
@ -122,6 +125,9 @@ def get_json_type(obj):
|
||||
if isinstance(obj, collections_abc.Mapping):
|
||||
return dict(obj)
|
||||
|
||||
if obj is Ellipsis:
|
||||
return {'class_name': '__ellipsis__'}
|
||||
|
||||
if isinstance(obj, wrapt.ObjectProxy):
|
||||
return obj.__wrapped__
|
||||
|
||||
|
@ -955,6 +955,8 @@ def _slice_helper(tensor, slice_spec, var=None):
|
||||
TypeError: If the slice indices aren't int, slice, ellipsis,
|
||||
tf.newaxis or scalar int32/int64 tensors.
|
||||
"""
|
||||
tensor = ops.convert_to_tensor(tensor)
|
||||
|
||||
if isinstance(slice_spec, bool) or \
|
||||
(isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
|
||||
(isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.proto_ops import decode_proto
|
||||
@ -28,6 +29,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -68,10 +70,38 @@ class TensorTracer(object):
|
||||
["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
|
||||
return "{}({})".format(self.name, ", ".join(args))
|
||||
|
||||
@classmethod
|
||||
def _overload_all_operators(cls): # pylint: disable=invalid-name
|
||||
"""Register overloads for all operators."""
|
||||
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
|
||||
cls._overload_operator(operator)
|
||||
|
||||
@classmethod
|
||||
def _overload_operator(cls, operator): # pylint: disable=invalid-name
|
||||
"""Overload an operator with the same overloading as `ops.Tensor`."""
|
||||
tensor_oper = getattr(ops.Tensor, operator)
|
||||
|
||||
# Compatibility with Python 2:
|
||||
# Python 2 unbound methods have type checks for the first arg,
|
||||
# so we need to extract the underlying function
|
||||
tensor_oper = getattr(tensor_oper, "__func__", tensor_oper)
|
||||
setattr(cls, operator, tensor_oper)
|
||||
|
||||
TensorTracer._overload_all_operators() # pylint: disable=protected-access
|
||||
|
||||
|
||||
class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
|
||||
"""Global op dispatcher for TensorTracer."""
|
||||
|
||||
def _flatten_with_slice_flattening(self, x):
|
||||
flat = []
|
||||
for val in nest.flatten(x):
|
||||
if isinstance(val, slice):
|
||||
flat.extend((val.start, val.stop, val.step))
|
||||
else:
|
||||
flat.append(val)
|
||||
return flat
|
||||
|
||||
def handle(self, op, args, kwargs):
|
||||
# Dispatcher only applies if at least one arg is a TensorTracer.
|
||||
if not (any(self.is_tensor_tracer_arg(x) for x in args) or
|
||||
@ -82,11 +112,8 @@ class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
|
||||
return TensorTracer(symbol_name, args, kwargs)
|
||||
|
||||
def is_tensor_tracer_arg(self, value):
|
||||
if isinstance(value, TensorTracer):
|
||||
return True
|
||||
if isinstance(value, (list, tuple)):
|
||||
if any(isinstance(x, TensorTracer) for x in value):
|
||||
return True
|
||||
return any(isinstance(x, TensorTracer) for x in
|
||||
self._flatten_with_slice_flattening(value))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@ -214,5 +241,46 @@ class DispatchTest(test_util.TensorFlowTestCase):
|
||||
# Clean up.
|
||||
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
||||
|
||||
def testGlobalDispatcherGetItem(self):
|
||||
original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
|
||||
try:
|
||||
TensorTracerOpDispatcher().register()
|
||||
|
||||
x = TensorTracer("x")
|
||||
trace = x[0]
|
||||
self.assertEqual(
|
||||
str(trace),
|
||||
"__operators__.getitem(x, 0)")
|
||||
|
||||
x = TensorTracer("x")
|
||||
y = TensorTracer("y")
|
||||
trace = x[y]
|
||||
self.assertEqual(
|
||||
str(trace),
|
||||
"__operators__.getitem(x, y)")
|
||||
|
||||
x = TensorTracer("x")
|
||||
y = TensorTracer("y")
|
||||
trace = x[:y] # pylint: disable=invalid-slice-index
|
||||
self.assertEqual(
|
||||
str(trace),
|
||||
"__operators__.getitem(x, slice(None, y, None))")
|
||||
|
||||
x = array_ops.ones(shape=(3, 3))
|
||||
y = TensorTracer("y")
|
||||
trace = x[y]
|
||||
self.assertEqual(
|
||||
str(trace),
|
||||
"__operators__.getitem(%s, y)" % x)
|
||||
|
||||
trace = x[:y] # pylint: disable=invalid-slice-index
|
||||
self.assertEqual(
|
||||
str(trace),
|
||||
"__operators__.getitem(%s, slice(None, y, None))" % x)
|
||||
|
||||
finally:
|
||||
# Clean up.
|
||||
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -70,6 +70,9 @@ def get_json_type(obj):
|
||||
if isinstance(obj, collections_abc.Mapping):
|
||||
return dict(obj)
|
||||
|
||||
if obj is Ellipsis:
|
||||
return {'class_name': '__ellipsis__'}
|
||||
|
||||
if isinstance(obj, wrapt.ObjectProxy):
|
||||
return obj.__wrapped__
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user