Add base_layer_utils to frozen_keras.

PiperOrigin-RevId: 299861942
Change-Id: I0079431e8c8f8b0b32cb3ae25a052fca3ec97e96

parent 482bb0c70d
commit b8da76a5fe
tensorflow/python/frozen_keras/engine/BUILD
@@ -11,6 +11,7 @@ py_library(
     name = "legacy_base_layer",
     srcs = ["legacy_base_layer.py"],
     deps = [
+        ":base_layer_utils",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:auto_control_deps",
@@ -37,7 +38,6 @@ py_library(
         "//tensorflow/python/keras:backend",
         "//tensorflow/python/keras:metrics",
         "//tensorflow/python/keras/engine",
-        "//tensorflow/python/keras/engine:base_layer_utils",
         "//tensorflow/python/keras/engine:input_spec",
         "//tensorflow/python/keras/saving",
         "//tensorflow/python/keras/utils:generic_utils",
@@ -55,6 +55,29 @@ py_library(
     ],
 )
 
+py_library(
+    name = "base_layer_utils",
+    srcs = ["base_layer_utils.py"],
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_util",
+        "//tensorflow/python:control_flow_v2_func_graphs",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:init_ops_v2",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tensor_util",
+        "//tensorflow/python:tf2",
+        "//tensorflow/python:util",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/keras:backend",
+        "//tensorflow/python/training/tracking:base",
+    ],
+)
+
 tf_py_test(
     name = "legacy_base_layer_test",
     size = "medium",
@@ -73,3 +96,17 @@ tf_py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+tf_py_test(
+    name = "base_layer_utils_test",
+    srcs = ["base_layer_utils_test.py"],
+    python_version = "PY3",
+    tags = [
+        "nomac",  # TODO(mihaimaruseac): b/127695564
+    ],
+    deps = [
+        ":base_layer_utils",
+        "//tensorflow/python:client_testlib",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
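
The dependency swap above implies the frozen module is now imported in place of the Keras engine copy; presumably legacy_base_layer.py (not shown in this section) gains the corresponding import:

    from tensorflow.python.frozen_keras.engine import base_layer_utils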
tensorflow/python/frozen_keras/engine/base_layer_utils.py (new file, 781 lines)
@@ -0,0 +1,781 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from tensorflow.python import tf2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.tracking import base as tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib

_call_context = threading.local()


def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
  """Temporary util to create a variable (relies on `variable_scope.variable`).

  Some reuse-related technicalities prevent us from using
  `variable_scope.get_variable()` directly, so we use a subcomponent
  that has fewer constraints (`variable_scope.variable()`).

  In the longer term, it seems like a similar "default variable creator" method
  should exist in `Trackable` instead. When this happens, we can get
  rid of this temporary solution.

  TODO(fchollet): remove this method when no longer needed.

  Arguments:
    name: Variable name.
    shape: Variable shape.
    dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
    initializer: Initializer instance (callable).
    trainable: Whether the variable should be part of the layer's
      "trainable_variables" (e.g. variables, biases)
      or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
      Note, if the current variable scope is marked as non-trainable
      then this parameter is ignored and any added variables are also
      marked as non-trainable. `trainable` defaults to `True` unless
      `synchronization` is set to `ON_READ`.
    caching_device: Passed to `tf.Variable`.
    validate_shape: Passed to `tf.Variable`.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    partitioner: Not handled at this time.

  Returns:
    Variable instance.
  """
  initializing_from_value = False
  if initializer is not None and not callable(initializer):
    initializing_from_value = True

  if initializing_from_value:
    init_val = initializer
    variable_dtype = None
  else:
    # Instantiate initializer if provided initializer is a type object.
    if isinstance(
        initializer,
        (type(init_ops.Initializer), type(init_ops_v2.Initializer))):
      initializer = initializer()
    init_val = lambda: initializer(shape, dtype=dtype)
    variable_dtype = dtype.base_dtype
  if use_resource is None:
    use_resource = True

  # TODO(apassos,rohanj) figure out how to remove collections from here so we
  # can remove the V1.
  variable_shape = tensor_shape.TensorShape(shape)
  return tf_variables.VariableV1(
      initial_value=init_val,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=variable_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=use_resource,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=variable_shape if variable_shape else None)
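
# A minimal usage sketch (hypothetical layer code; `input_dim` and `units`
# are assumed to be defined by the caller):
#
#   kernel = make_variable(
#       'kernel',
#       shape=(input_dim, units),
#       dtype=dtypes.float32,
#       initializer=init_ops_v2.GlorotUniform(),
#       trainable=True)
#
# An initializer class (rather than an instance) may also be passed;
# `make_variable` instantiates initializer type objects before calling them.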


def collect_previous_mask(input_tensors):
  """Retrieves the output mask(s) of the previous node.

  Arguments:
    input_tensors: An arbitrary structure of Tensors.

  Returns:
    A mask tensor or list of mask tensors.
  """

  def _collect_previous_mask(x):
    return getattr(x, '_keras_mask', None)

  return nest.map_structure(_collect_previous_mask, input_tensors)


def have_all_keras_metadata(tensors):
  return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))


def generate_placeholders_from_shape(shape):
  return array_ops.placeholder(shape=shape, dtype=backend.floatx())


def create_keras_history(tensors):
  """Wraps TensorFlow Operations for compatibility with the Functional API.

  This method checks to see if a Tensor in `tensors` is missing Keras metadata
  and has its origin in a Keras `Input` Layer. If so, this method will replace
  the raw TensorFlow Operations that created this tensor with
  `TensorFlowOpLayer` instances that create identical operations.

  Any Tensors not originating from a Keras `Input` Layer will be treated as
  constants when constructing `TensorFlowOpLayer` instances.

  Arguments:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.

  Returns:
    created_layers: List. The `TensorFlowOpLayer` instances created to wrap
      the raw TensorFlow operations.
  """
  _, created_layers = _create_keras_history_helper(tensors, set(), [])
  return created_layers
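
# A minimal sketch of the case this handles (assumes the public tf / tf.keras
# API in the caller's context; these names are not imports of this module):
#
#   inputs = tf.keras.Input(shape=(4,))
#   outputs = tf.abs(inputs)        # raw TF op: output lacks `_keras_history`
#   created = create_keras_history(outputs)
#   # The `Abs` op is now wrapped in a `TensorFlowOpLayer`, so `outputs`
#   # can participate in a functional `tf.keras.Model(inputs, outputs)`.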


def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Arguments:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped in
      `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
      have been wrapped in `TensorFlowOpLayer` instances. Second element is
      a list of the `TensorFlowOpLayer` instances created.
  """
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.frozen_keras.engine import legacy_base_layer as base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = nest.flatten(tensors)
  for tensor in tensor_list:
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      if op.type.startswith('Sparse'):
        lambda_example = """
        weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
        output = tf.keras.layers.Lambda(weights_mult)(input)
        """
        raise ValueError(
            'Sparse ops are not supported with functional models with built-in '
            'layer wrapping. Please wrap the sparse ops in a Lambda layer like'
            ': \n{lambda_example}\n'.format(lambda_example=lambda_example))

      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      constants = {}
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          ds_with_session = (
              distribution_strategy_context.in_cross_replica_context() and
              not ops.executing_eagerly_outside_functions())
          using_xla = control_flow_util.GraphOrParentsInXlaContext(
              ops.get_default_graph())
          if ds_with_session or using_xla:
            # In Legacy Graph mode, evaluating here makes Session be
            # configured improperly. The downside of this is that saving
            # via `get_config` breaks, but SavedModel still works.
            constants[i] = op_input
          else:
            with ops.init_scope():
              constants[i] = backend.function([], op_input)([])
      layer_inputs = unnest_if_single_tensor(layer_inputs)
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._add_inbound_node(  # pylint: disable=protected-access
          layer_inputs, op.outputs)
      processed_ops.update([op])
  return processed_ops, created_layers


def unnest_if_single_tensor(input_tensors):
  # Preserve compatibility with older configs
  flat_input_tensors = nest.flatten(input_tensors)
  # If this is a single element but not a dict, unwrap. If this is a dict,
  # assume the first layer expects a dict (as is the case with a
  # DenseFeatures layer); pass through.
  if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
    input_tensors = flat_input_tensors[0]
  return input_tensors


def needs_keras_history(tensors, ignore_call_context=False):
  """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

  This will never return True inside a sublayer, because sublayers
  do not need to create Keras History. Otherwise, this returns True
  if one or more of `tensors` originates from a `keras.Input` and
  does not have `_keras_history` set.

  Arguments:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to ignore the check of if currently
      outside of a `call` context. This is `True` when creating
      KerasHistory inside `Node`, where we always know that Tensors
      are being used with the Functional API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  input_tensors = nest.flatten(tensors)
  if call_context().in_call and not ignore_call_context:
    return False
  if all(
      getattr(tensor, '_keras_history', None) is not None
      for tensor in input_tensors):
    # KerasHistory already set.
    return False
  return uses_keras_history(tensors)


def is_in_keras_graph():
  """Returns if currently executing inside of a Keras graph."""
  return call_context().in_keras_graph


def is_in_eager_or_tf_function():
  """Returns if in eager mode or inside of a tf.function."""
  return context.executing_eagerly() or is_in_tf_function()


def is_in_tf_function():
  """Returns if inside of a tf.function."""
  # Check if running in V1 graph mode.
  if not ops.executing_eagerly_outside_functions():
    return False
  if not ops.inside_function():
    return False
  # Check if inside Keras FuncGraph.
  if is_in_keras_graph():
    return False
  # Check for a v1 `wrap_function` FuncGraph.
  graph = ops.get_default_graph()
  if (getattr(graph, 'name', False) and
      graph.name.startswith('wrapped_function')):
    return False
  return True
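
# A small sketch of what these checks distinguish (assuming TF 2.x
# eager-by-default behavior; `tf` is the public API, not an import here):
#
#   @tf.function
#   def fn():
#     print(is_in_tf_function())  # True while `fn` is being traced
#   fn()
#
#   print(is_in_tf_function())    # False in plain eager execution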


def uses_keras_history(tensors):
  """Check if at least one Tensor originates from a `keras.Input`.

  This is `True` if at least one Tensor has its origin in a `keras.Input`.
  Any Tensor that originates from a `keras.Input` will have a dependency
  Tensor with a `_keras_history` attribute attached. Tensors that have
  already been checked to not originate from a `keras.Input`
  are marked as `_keras_history_checked`.

  Arguments:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  checked_tensors = set()
  tensors_to_check = nest.flatten(tensors)

  while tensors_to_check:
    new_tensors_to_check = []
    for tensor in tensors_to_check:
      if id(tensor) in checked_tensors:
        continue

      checked_tensors.add(id(tensor))

      if getattr(tensor, '_keras_history_checked', None) is not None:
        continue
      if getattr(tensor, '_keras_history', None) is not None:
        return True

      try:
        new_tensors_to_check.extend(tensor.op.inputs)
      except AttributeError:
        # In case `tensor` is a Variable created in an Eager context.
        pass

    tensors_to_check = new_tensors_to_check

  # Mark that these Tensors have been checked once for `_keras_history`,
  # and should not be checked again for performance reasons.
  mark_checked(tensors)
  return False
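
# For instance (assuming the public tf / tf.keras API in the caller's
# context):
#
#   x = tf.keras.Input(shape=(2,))
#   uses_keras_history(tf.square(x))           # True: walking op inputs hits x
#   uses_keras_history(tf.constant([1., 2.]))  # False; also marked as checked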


def mark_checked(tensors):
  """Marks that these Tensors should not be tracked.

  This prevents Layers from attempting to create TensorFlowOpLayers
  for these Tensors.

  Arguments:
    tensors: An arbitrary structure of Tensors.
  """

  def _mark_checked(tensor):
    tensor._keras_history_checked = True  # pylint: disable=protected-access

  nest.map_structure(_mark_checked, tensors)


def call_context():
  """Returns currently active `CallContext`."""
  if getattr(_call_context, 'call_context', None) is None:
    _call_context.call_context = CallContext()
  return _call_context.call_context


class CallContext(object):
  """Keeps track of properties currently inside a Layer/Model's `call`.

  Attributes:
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    frozen: Whether currently executing inside a `Layer` with `trainable` set
      to `False`.
    in_call: Whether currently inside the `call` of a Layer.
    training: Whether currently executing in training or inference mode.
    in_keras_graph: Whether executing inside the Keras Graph.
    saving: Whether currently saving to SavedModel.
  """

  def __init__(self):
    self.layer = None
    self.inputs = None
    self.frozen = False
    self.in_call = False
    self.training = None
    self._in_keras_graph = False
    self.saving = False

  @tf_contextlib.contextmanager
  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Push a Layer and its inputs and state onto the current call context."""
    prev_layer = self.layer
    prev_inputs = self.inputs
    prev_frozen = self.frozen
    prev_in_call = self.in_call
    prev_training = self.training
    prev_in_keras_graph = self._in_keras_graph
    prev_saving = self.saving

    self.layer = layer
    self.inputs = inputs
    self.frozen = self.frozen or not layer.trainable
    self.in_call = True
    self.training = training
    self._in_keras_graph = (
        self._in_keras_graph or
        (build_graph and
         getattr(backend.get_graph(), 'name', None) == 'keras_graph'))
    self.saving = prev_saving if saving is None else saving

    try:
      yield
    finally:
      self.layer = prev_layer
      self.inputs = prev_inputs
      self.frozen = prev_frozen
      self.in_call = prev_in_call
      self.training = prev_training
      self._in_keras_graph = prev_in_keras_graph
      self.saving = prev_saving

  @property
  def in_keras_graph(self):
    # Returns True even if in a subgraph of the Keras graph, such as those
    # created by control flow ops.
    if context.executing_eagerly():
      return False
    return (self._in_keras_graph or
            getattr(backend.get_graph(), 'name', None) == 'keras_graph')
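
# A minimal sketch of how a base layer is expected to use this context (the
# `layer` object and `inputs` are hypothetical):
#
#   ctx = call_context()
#   with ctx.enter(layer, inputs, build_graph=False, training=True):
#     outputs = layer.call(inputs)  # ctx.in_call is True in here
#   # On exit, the previous state is restored, so contexts nest safely.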


def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether a user passed the `training` argument in `__call__`."""
  # `argspec.args` starts with ['self', 'inputs']
  full_args = dict(zip(argspec.args[2:], args))
  full_args.update(kwargs)
  return 'training' in full_args and full_args['training'] is not None


def autocast_context_manager(dtype):
  """Returns a context manager to autocast AutoCastVariables.

  Under this context manager, AutoCastVariables will be cast to `dtype` if
  `dtype` is floating-point. Otherwise, AutoCastVariables will not be cast.

  Args:
    dtype: The dtype to cast AutoCastVariables to, or None.

  Returns:
    A context manager to automatically cast AutoCastVariables.
  """
  if dtype and not dtypes.as_dtype(dtype).is_floating:
    dtype = None
  return ops.get_default_graph()._enable_auto_casting_variables(dtype)  # pylint: disable=protected-access


def is_subclassed(layer):
  """Returns True if the object is a subclassed layer or subclassed model."""
  return (layer.__module__.find('keras.engine') == -1 and
          layer.__module__.find('keras.layers') == -1)


def from_saved_model(layer):
  """Returns whether the layer is loaded from a SavedModel."""
  return layer.__module__.find('keras.saving.saved_model') != -1


def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Checks that tensors passed to `add_*` methods match the Keras graph.

  When one of the `add_*` methods is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
  We need to raise clear error messages in such cases.

  Arguments:
    tensor: Tensor to check, or `False` if it is known that an error
      should be raised.
    method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: In case of an out-of-graph tensor.
  """
  if (force_raise or
      (ops.executing_eagerly_outside_functions() and
       hasattr(tensor, 'graph') and
       isinstance(tensor.graph,
                  (control_flow_v2_func_graphs.CondBranchFuncGraph,
                   control_flow_v2_func_graphs.WhileCondFuncGraph,
                   control_flow_v2_func_graphs.WhileBodyFuncGraph)))):
    if method == 'activity_regularizer':
      bad_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
      correct_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          return self.dense(x)
      """
      raise RuntimeError(
          'You are using a layer with `activity_regularizer` in a control flow '
          'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
          'Please move your call to the layer with `activity_regularizer` out '
          'of the control flow branch, e.g.:\n{correct_example}\n'
          'You can also resolve this by marking your outer model/layer dynamic'
          ' (eager-only) by passing `dynamic=True` to the layer constructor. '
          'Any kind of control flow is supported with dynamic layers. '
          'Note that using `dynamic=True` requires you to implement static '
          'shape inference in the `compute_output_shape(input_shape)` '
          'method.'.format(
              bad_example=bad_example, correct_example=correct_example))

    if method == 'add_metric':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    elif method == 'add_loss':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
    else:
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
    raise RuntimeError(
        'You are using the method `{method}` in a control flow branch '
        'in your layer, e.g.:\n{bad_example}\n'
        'This is not currently supported. '
        'Please move your call to {method} out of the control flow branch, '
        'e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your layer '
        'as dynamic (eager-only) by passing '
        '`dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you '
        'to implement static shape inference '
        'in the `compute_output_shape(input_shape)` method.'.format(
            method=method,
            bad_example=bad_example,
            correct_example=correct_example))


def mark_as_return(outputs, acd):
  """Marks `outputs` as the return values for automatic control deps."""

  def _mark_as_return(tensor):
    """Marks `tensor` as the return value for automatic control deps."""
    if not tensor_util.is_tensor(tensor):
      return tensor

    # pylint: disable=protected-access
    return_tensor = acd.mark_as_return(tensor)
    if getattr(tensor, '_keras_mask', None) is not None:
      return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
    else:
      return_tensor._keras_mask = None

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    if getattr(tensor, '_tfp_distribution', None) is not None:
      return_tensor._tfp_distribution = tensor._tfp_distribution

    return return_tensor
    # pylint: enable=protected-access

  return nest.map_structure(_mark_as_return, outputs)


V2_DTYPE_BEHAVIOR = None


# These two functions are not exported because we plan on removing them in the
# future.
def enable_v2_dtype_behavior():
  """Enable the V2 dtype behavior for Keras layers.

  By default, the V2 dtype behavior is enabled in TensorFlow 2.

  When enabled, the dtype of Keras layers defaults to floatx (which is typically
  float32) instead of None. In addition, layers will automatically cast
  floating-point inputs to the layer's dtype.

  For example, once enabled, the following block will run a Conv2D layer
  in float32:

  ```python
  x = tf.ones((4, 4, 4, 4), dtype='float64')
  layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  print(layer.dtype)  # Float32 when enabled. None when disabled.
  # When enabled, will cast inputs to the layer's dtype, which is float32. When
  # disabled, will do no casting, so the layer is done in float64.
  y = layer(x)
  ```

  A layer author can opt-out their layer from the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. This disables the
  autocasting part of the V2 behavior for that layer, but not the defaulting to
  floatx part of the V2 behavior.

  When a global `tf.keras.mixed_precision.experimental.Policy` is set, the
  layer's dtype will default to the global policy instead of floatx. Layers
  will automatically cast inputs to the policy's compute_dtype.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True


def disable_v2_dtype_behavior():
  """Disables the V2 dtype behavior for Keras layers.

  See `enable_v2_dtype_behavior`.

  This function will be removed in the future.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False


def v2_dtype_behavior_enabled():
  """Returns True if the V2 dtype behavior is enabled."""
  if V2_DTYPE_BEHAVIOR is None:
    return tf2.enabled()
  return V2_DTYPE_BEHAVIOR


class TrackableWeightHandler(object):
  """Keras wrapper for handling tracking.Trackable object saving and restoring.

  This class handles Trackables in both V1 and V2 modes, ensuring that they can
  be saved and restored with the correct data and without adding additional ops
  on every save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for saving.
  """

  def __init__(self, trackable):
    if not isinstance(trackable, tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
    if len(saveables) != 1:
      raise ValueError('Only Trackables with one Saveable are supported.')
    saveable = list(saveables)[0]

    if ops.executing_eagerly_outside_functions():
      # If we're in eager mode, we need to defer calling the Trackable's
      # saveable() callable until data export time.
      # However, it is safe to call the saveable as many times as we want, so
      # we will call it now to figure out how many tensors this Trackable will
      # produce.
      self._saveable = saveable
      self._num_tensors = len(self._saveable().specs)
      self._setter = lambda weights: self._saveable().restore(weights, None)
      self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
    else:
      # If we're in Graph mode, we need to evaluate the Saveable only once and
      # cache the resulting restore graph. Failing to do this will result in
      # new assignment ops being added to the graph each time set_weights() is
      # called.
      self._placeholder_tensors = []
      self._saveable = saveable()
      self._num_tensors = len(self._saveable.specs)
      for spec in self._saveable.specs:
        tensor = spec.tensor
        self._placeholder_tensors.append(
            array_ops.placeholder(tensor.dtype, tensor.shape))
      self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
      self._setter = self._set_weights_v1
      self._getter = lambda: [spec.tensor for spec in self._saveable.specs]

  @property
  def num_tensors(self):
    return self._num_tensors

  def set_weights(self, weights):
    if len(weights) != self._num_tensors:
      raise ValueError(
          ('Weight handler for trackable %s received the wrong number of ' +
           'weights: expected %s, got %s.') %
          (self._trackable, self._num_tensors, len(weights)))
    self._setter(weights)

  def get_tensors(self):
    return self._getter()

  def _set_weights_v1(self, weights):
    feed_dict = {}
    for idx, tensor in enumerate(weights):
      feed_dict[self._placeholder_tensors[idx]] = tensor
    backend.get_session().run(self._assign_op, feed_dict)
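
# A short sketch, mirroring the unit test below: wrapping a MutableHashTable
# (a Trackable that is not a Variable) lets Keras save and restore it through
# the weights API (key/value lists here are illustrative):
#
#   table = lookup_ops.MutableHashTable(
#       key_dtype=dtypes.string, value_dtype=dtypes.int32, default_value=0)
#   handler = TrackableWeightHandler(table)
#   handler.set_weights([[b'a', b'b'], [1, 2]])  # keys, then values
#   tensors = handler.get_tensors()              # two tensors for saving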


# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes issues
# in eager mode because the child layers may have graph losses
# (thus model.losses returns a mix of Eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    'This layer\'s losses have been added to the parent layer.')
tensorflow/python/frozen_keras/engine/base_layer_utils_test.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.frozen_keras.engine import base_layer_utils
from tensorflow.python.keras import backend
from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class TrackableWeightHandlerTest(test.TestCase, parameterized.TestCase):

  def get_table_handler(self):
    # Note: There is some repetition in these tests' setup. However, TensorFlow
    # does not play nicely with a separate setUp() call (causing errors related
    # to graph building), so we have to use a called setup instead of a setUp()
    # call.
    table = lookup_ops.MutableHashTable(
        key_dtype=dtypes.string, value_dtype=dtypes.int32, default_value=0)
    return base_layer_utils.TrackableWeightHandler(table)

  def test_get_num_tensors(self):
    table_handler = self.get_table_handler()
    self.assertEqual(2, table_handler.num_tensors)

  def test_get_and_set_weights(self):
    table_handler = self.get_table_handler()

    table_data = {b"a": 1, b"b": 2, b"c": 3}
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    weights = backend.batch_get_value(table_handler.get_tensors())
    weight_data = {key: value for key, value in zip(weights[0], weights[1])}
    self.assertDictEqual(table_data, weight_data)

  def test_get_and_set_weights_does_not_add_ops(self):
    table_handler = self.get_table_handler()
    table_data = {b"a": 1, b"b": 2, b"c": 3}
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    _ = backend.batch_get_value(table_handler.get_tensors())
    backend.get_session().graph.finalize()
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    _ = backend.batch_get_value(table_handler.get_tensors())


if __name__ == "__main__":
  test.main()