STT-tensorflow/tensorflow/python/util/serialization.py
Tomer Kaftan e3be70aa9d Add a convert_to_tensor to the start of Tensor.__getitem__ (_slice_helper) to make sure it dispatches directly, rather than letting the nested tf.strided_slice trigger dispatching.
This is important because `tensor.__getitem__` does some input arg manipulation before getting to the `tf.strided_slice`. So, when we try to run the traced code using the args provided to `strided_slice` (e.g. for KerasTensors), we lose information about constants that TPUs need to compile graphs involving shape manipulation. Tracing `__getitem__` and its input args directly does not seem to run into this problem.

(Note: this TPU situation is separate from the shape-value inference we do in KerasTensors at Functional API construction/tracing time. The TPU issue arises at model run time, when running the already-traced code.)

To get this all to work correctly in practice when dispatching KerasTensors + serializing/deserializing Keras models, this CL also has to:

* Add special KerasTensor dispatchers for APIs that may take `slices` as inputs, to make sure they can trigger dispatching & serialize/deserialize correctly. This specialized dispatcher makes sure to unpack any `slices` in the args/kwargs into a namedtuple, before passing it to a specialized Keras TFOpLambda subclass that re-packs any slices.

* Add serialization/deserialization support for `ellipsis` objects in Keras
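
The unpack/re-pack step in the first bullet can be sketched roughly as follows. This is only an illustration under stated assumptions: `SliceTuple`, `unpack_slices`, and `repack_slices` are hypothetical names, and the actual KerasTensor dispatcher and TFOpLambda subclass from this CL are not reproduced here.

```python
import collections

# Hypothetical stand-in for the namedtuple that slices are unpacked into.
SliceTuple = collections.namedtuple('SliceTuple', ['start', 'stop', 'step'])


def unpack_slices(arg):
  """Recursively replaces `slice` objects with SliceTuple instances."""
  if isinstance(arg, slice):
    return SliceTuple(arg.start, arg.stop, arg.step)
  # Only plain lists and tuples are traversed in this sketch.
  if type(arg) in (list, tuple):
    return type(arg)(unpack_slices(a) for a in arg)
  return arg


def repack_slices(arg):
  """Inverse of unpack_slices: turns SliceTuple instances back into slices."""
  if isinstance(arg, SliceTuple):
    return slice(arg.start, arg.stop, arg.step)
  if type(arg) in (list, tuple):
    return type(arg)(repack_slices(a) for a in arg)
  return arg


# Example: the key used in something like `tensor[1::2, 0]`.
key = (slice(1, None, 2), 0)
unpacked = unpack_slices(key)
restored = repack_slices(unpacked)
```

A namedtuple is made only of plain fields, so generic structure-handling code can traverse it, whereas a raw `slice` is opaque to such code.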

------------------------
Other considered alternatives to get the dispatching/serialization to work correctly for KerasTensors:

* Add flatten/pack support for slices to `tf.nest`/`tree`. This can be revisited in the future (especially re: dispatchv2), but tree is critical-path code and it is not obvious whether slices should always be flattened/packed.

* Make the dispatched `__operators__.getitem` method expect slices to have already been unwrapped, and add a step to the `__getitem__` override that unwraps the slices. This would be somewhat clunky in practice because other TF APIs also take `slice`s in their args, and users of dispatch might be surprised that the `__operators__.getitem` dispatch does not match the standard `__getitem__` API. It is also unclear what the performance cost of the extra packing/unpacking would be when not dispatching.

PiperOrigin-RevId: 322655930
Change-Id: I35417577199393c016f753be685bf2926d62e753
2020-07-22 14:51:41 -07:00


# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for serializing Python objects."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import wrapt

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util.compat import collections_abc


def get_json_type(obj):
  """Serializes any object to a JSON-serializable structure.

  Arguments:
    obj: the object to serialize

  Returns:
    JSON-serializable structure representing `obj`.

  Raises:
    TypeError: if `obj` cannot be serialized.
  """
  # if obj is a serializable Keras class instance
  # e.g. optimizer, layer
  if hasattr(obj, 'get_config'):
    return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}

  # if obj is any numpy type
  if type(obj).__module__ == np.__name__:
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    else:
      return obj.item()

  # misc functions (e.g. loss function)
  if callable(obj):
    return obj.__name__

  # if obj is a python 'type'
  if type(obj).__name__ == type.__name__:
    return obj.__name__

  if isinstance(obj, tensor_shape.Dimension):
    return obj.value

  if isinstance(obj, tensor_shape.TensorShape):
    return obj.as_list()

  if isinstance(obj, dtypes.DType):
    return obj.name

  if isinstance(obj, collections_abc.Mapping):
    return dict(obj)

  if obj is Ellipsis:
    return {'class_name': '__ellipsis__'}

  if isinstance(obj, wrapt.ObjectProxy):
    return obj.__wrapped__

  raise TypeError('Not JSON Serializable:', obj)
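
A function with this shape (return a JSON-friendly value or raise `TypeError`) fits the `default` hook of the standard library's `json.dumps`, which is presumably how it is meant to be used. The sketch below illustrates that pattern with the standard library only. `get_json_type_sketch` is a simplified stand-in that copies just a few branches of the function above, and `FakeLayer` and `mean_loss` are hypothetical examples, not TensorFlow APIs.

```python
import collections.abc
import json


def get_json_type_sketch(obj):
  """Simplified stand-in covering only a few branches of get_json_type."""
  if hasattr(obj, 'get_config'):
    return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
  if callable(obj):
    return obj.__name__
  if isinstance(obj, collections.abc.Mapping):
    return dict(obj)
  if obj is Ellipsis:
    return {'class_name': '__ellipsis__'}
  raise TypeError('Not JSON Serializable:', obj)


class FakeLayer(object):
  """Hypothetical object exposing get_config, like a Keras layer."""

  def get_config(self):
    return {'units': 4}


def mean_loss(y_true, y_pred):
  """Hypothetical loss function; only its name is serialized."""
  return 0.0


# json.dumps calls the `default` hook for every object it cannot encode itself.
payload = {'layer': FakeLayer(), 'loss': mean_loss, 'index': [Ellipsis, 0]}
encoded = json.dumps(payload, default=get_json_type_sketch, sort_keys=True)
decoded = json.loads(encoded)
```

Here `decoded['index'][0]` comes back as the marker dict `{'class_name': '__ellipsis__'}`. Turning it back into `Ellipsis` is the job of the deserialization side, which is not part of this file.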