This is important because `tensor.__getitem__` does some input-argument manipulation before it reaches `tf.strided_slice`. So when we try to run the traced code using the args that were provided to `strided_slice` (e.g. for KerasTensors), we lose information about constants that TPUs need in order to compile graphs involving shape manipulation. Tracing `__getitem__` and its input args directly does not seem to run into this problem.

(Note: this TPU situation is separate from the shape-value inference we do in KerasTensors during Functional API construction/tracing time. It happens at model run time, when running the already-traced code.)

To get this all to work correctly in practice when dispatching KerasTensors and serializing/deserializing Keras models, this CL also has to:

* Add special KerasTensor dispatchers for APIs that may take `slice`s as inputs, so that they can trigger dispatching and serialize/deserialize correctly. This specialized dispatcher unpacks any `slice`s in the args/kwargs into a namedtuple before passing them to a specialized Keras `TFOpLambda` subclass that re-packs the slices.
* Add serialization/deserialization support for `ellipsis` objects in Keras.

------------------------

Other alternatives considered for getting dispatching/serialization to work correctly for KerasTensors:

* Add flatten/pack support for slices to `tf.nest`/`tree`. This can be revisited in the future (especially re: dispatchv2), but `tree` is critical-path code, and it's not obvious whether we should always be flattening/packing slices.
* Make the dispatched `__operators__.getitem` method expect slices to have already been unwrapped, and add a step to the `__getitem__` override that unwraps the slices. This would be somewhat clunky in practice, because other TF APIs take `slice`s in their args as well, and it might surprise dispatch users that the `__operators__.getitem` dispatch doesn't actually match the standard `__getitem__` API. Likewise, it's unclear what the performance cost of the extra packing/unpacking would be even when not dispatching.

PiperOrigin-RevId: 322655930
Change-Id: I35417577199393c016f753be685bf2926d62e753
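The slice unpack/re-pack round trip described in the commit message can be sketched in plain Python as follows. This is a minimal sketch, not the actual Keras implementation: `SliceSpec`, `unpack_slices`, `repack_slices`, `toy_flatten`, and `FakeKerasTensor` are hypothetical names introduced for illustration, and `toy_flatten` only approximates how `tf.nest` treats lists, tuples (including namedtuples), and dicts as containers while treating every other object, including a `slice`, as a single leaf.

```python
import collections

# Hypothetical stand-in for the namedtuple the dispatcher unpacks slices into.
SliceSpec = collections.namedtuple('SliceSpec', ['start', 'stop', 'step'])


class FakeKerasTensor(object):
  """Hypothetical placeholder for a symbolic KerasTensor."""


def toy_flatten(structure):
  """Rough stand-in for `tf.nest.flatten`.

  Lists, tuples (including namedtuples) and dicts are containers; every
  other object, including a `slice`, is returned as a single leaf.
  """
  if isinstance(structure, (list, tuple)):
    return [leaf for item in structure for leaf in toy_flatten(item)]
  if isinstance(structure, dict):
    return [leaf for key in sorted(structure)
            for leaf in toy_flatten(structure[key])]
  return [structure]


def unpack_slices(structure):
  """Replaces every `slice` in a nested list/tuple/dict with a `SliceSpec`."""
  if isinstance(structure, slice):
    return SliceSpec(structure.start, structure.stop, structure.step)
  if type(structure) in (list, tuple):
    return type(structure)(unpack_slices(item) for item in structure)
  if isinstance(structure, dict):
    return {key: unpack_slices(value) for key, value in structure.items()}
  return structure


def repack_slices(structure):
  """Inverse of `unpack_slices`: turns every `SliceSpec` back into a `slice`."""
  if isinstance(structure, SliceSpec):
    return slice(structure.start, structure.stop, structure.step)
  if type(structure) in (list, tuple):
    return type(structure)(repack_slices(item) for item in structure)
  if isinstance(structure, dict):
    return {key: repack_slices(value) for key, value in structure.items()}
  return structure


# The key that `x[..., 1:t]` would pass to `__getitem__`.
t = FakeKerasTensor()
key = (Ellipsis, slice(1, t))

# The tensor is hidden inside the slice leaf before unpacking...
print(any(leaf is t for leaf in toy_flatten(key)))                 # False
# ...and becomes a visible leaf after unpacking.
print(any(leaf is t for leaf in toy_flatten(unpack_slices(key))))  # True
# Re-packing restores the original key.
print(repack_slices(unpack_slices(key)) == key)                    # True
```

Because a namedtuple is a container for `tf.nest` while a `slice` is an opaque leaf, unpacking first is what lets a symbolic tensor used as a slice bound show up when the arguments are flattened, and re-packing hands the op the same key that `__getitem__` originally received.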
80 lines
2.2 KiB
Python
```python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for serializing Python objects."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import wrapt

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util.compat import collections_abc


def get_json_type(obj):
  """Serializes any object to a JSON-serializable structure.

  Arguments:
    obj: the object to serialize

  Returns:
    JSON-serializable structure representing `obj`.

  Raises:
    TypeError: if `obj` cannot be serialized.
  """
  # if obj is a serializable Keras class instance
  # e.g. optimizer, layer
  if hasattr(obj, 'get_config'):
    return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}

  # if obj is any numpy type
  if type(obj).__module__ == np.__name__:
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    else:
      return obj.item()

  # misc functions (e.g. loss function)
  if callable(obj):
    return obj.__name__

  # if obj is a python 'type'
  if type(obj).__name__ == type.__name__:
    return obj.__name__

  if isinstance(obj, tensor_shape.Dimension):
    return obj.value

  if isinstance(obj, tensor_shape.TensorShape):
    return obj.as_list()

  if isinstance(obj, dtypes.DType):
    return obj.name

  if isinstance(obj, collections_abc.Mapping):
    return dict(obj)

  if obj is Ellipsis:
    return {'class_name': '__ellipsis__'}

  if isinstance(obj, wrapt.ObjectProxy):
    return obj.__wrapped__

  raise TypeError('Not JSON Serializable:', obj)
```