Merge pull request #31511 from tensorflow/ggadde-cp3
Cherrypicks to fix the disconnected graph issue, and missing CUDA compute capabilities.
This commit is contained in:
commit
926e66c254
@ -225,7 +225,8 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
|
||||
# configured improperly.
|
||||
constants[i] = op_input
|
||||
else:
|
||||
constants[i] = backend.function([], op_input)([])
|
||||
with ops.init_scope():
|
||||
constants[i] = backend.function([], op_input)([])
|
||||
processed_ops, created_layers = _create_keras_history_helper(
|
||||
layer_inputs, processed_ops, created_layers)
|
||||
name = op.name
|
||||
@ -239,7 +240,7 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
|
||||
return processed_ops, created_layers
|
||||
|
||||
|
||||
def needs_keras_history(tensors):
|
||||
def needs_keras_history(tensors, ignore_call_context=False):
|
||||
"""Check if any Tensors need to be wrapped in TensorFlowOpLayers.
|
||||
|
||||
This will never return True inside a sublayer, because sublayers
|
||||
@ -249,12 +250,18 @@ def needs_keras_history(tensors):
|
||||
|
||||
Arguments:
|
||||
tensors: An arbitrary nested structure of Tensors.
|
||||
ignore_call_context: Whether to skip the check that we are currently
|
||||
inside a `call` context. This is `True` when creating
|
||||
KerasHistory inside `Node`, where we always know that Tensors
|
||||
are being used with the Functional API.
|
||||
|
||||
Returns:
|
||||
Bool, whether at least one Tensor needs to be wrapped.
|
||||
"""
|
||||
input_tensors = nest.flatten(tensors)
|
||||
if call_context().in_call or all(
|
||||
if call_context().in_call and not ignore_call_context:
|
||||
return False
|
||||
if all(
|
||||
getattr(tensor, '_keras_history', None) is not None
|
||||
for tensor in input_tensors):
|
||||
# KerasHistory already set.
|
||||
|
@ -1547,7 +1547,7 @@ class Network(base_layer.Layer):
|
||||
def _get_min_depth(node):
|
||||
"""Gets the minimum depth at which node can be computed."""
|
||||
min_depth = 0
|
||||
for layer, node_id, _, _ in node.iterate_inbound():
|
||||
for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True):
|
||||
inbound_node = layer._inbound_nodes[node_id]
|
||||
if inbound_node in node_to_depth:
|
||||
min_depth = min(min_depth, node_to_depth[inbound_node])
|
||||
@ -1720,7 +1720,8 @@ def _map_graph_network(inputs, outputs):
|
||||
nodes_in_progress.add(node)
|
||||
|
||||
# Propagate to all previous tensors connected to this node.
|
||||
for layer, node_index, tensor_index, tensor in node.iterate_inbound():
|
||||
for layer, node_index, tensor_index, tensor in node.iterate_inbound(
|
||||
include_arguments=True):
|
||||
build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index,
|
||||
tensor_index)
|
||||
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
@ -111,6 +112,15 @@ class Node(object):
|
||||
# Optional keyword arguments to layer's `call`.
|
||||
self.arguments = arguments
|
||||
|
||||
# Create Keras History for any Keras Tensors in `arguments`.
|
||||
tensor_arguments = [
|
||||
t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor)
|
||||
]
|
||||
for tensor_argument in tensor_arguments:
|
||||
if base_layer_utils.needs_keras_history(
|
||||
tensor_argument, ignore_call_context=True):
|
||||
base_layer_utils.create_keras_history(tensor_argument)
|
||||
|
||||
# Add nodes to all layers involved.
|
||||
for layer in nest.flatten(inbound_layers):
|
||||
if layer is not None:
|
||||
@ -121,15 +131,39 @@ class Node(object):
|
||||
# accessor here.
|
||||
outbound_layer.inbound_nodes.append(self)
|
||||
|
||||
def iterate_inbound(self, include_arguments=False):
  """Returns a list of tuples representing the inbound data.

  Arguments:
    include_arguments: Whether to also iterate over any Keras Tensors
      passed as args, kwargs.

  Returns:
    List of tuples like: (inbound_layer, node_index, tensor_index, tensor).
  """
  # Materialize the zip so the result is a real list (as documented) and can
  # be concatenated with the argument entries below.
  inputs_inbound = list(
      zip(
          nest.flatten(self.inbound_layers),
          nest.flatten(self.node_indices),
          nest.flatten(self.tensor_indices),
          nest.flatten(self.input_tensors)))

  if not include_arguments:
    return inputs_inbound

  def _get_inbound(keras_tensor):
    # Same 4-tuple layout as the entries of `inputs_inbound`, read from the
    # tensor's `_keras_history`.
    kh = keras_tensor._keras_history
    return kh.layer, kh.node_index, kh.tensor_index, keras_tensor

  # Only objects in `arguments` that carry `_keras_history` contribute
  # inbound entries; everything else is ignored.
  arguments_inbound = [
      _get_inbound(kt)
      for kt in nest.flatten(self.arguments)
      if hasattr(kt, '_keras_history')
  ]

  return inputs_inbound + arguments_inbound
|
||||
|
||||
def _get_all_node_dependencies(self):
|
||||
"""Returns all of the nodes this node immediately depends on."""
|
||||
|
@ -135,6 +135,19 @@ def _float64_op():
|
||||
return keras.Model(inputs, outputs)
|
||||
|
||||
|
||||
class MyAdd(keras.layers.Layer):
  """Layer whose `call` takes two arguments and returns their sum."""

  def call(self, x, y):
    total = x + y
    return total
|
||||
|
||||
|
||||
def _layer_with_tensor_arg():
  """Builds a model in which `MyAdd` receives an op-derived tensor as its second call argument."""
  model_input = keras.Input(shape=(10,))
  doubled = model_input * 2
  summed = MyAdd()(model_input, doubled)
  return keras.Model(model_input, summed)
|
||||
|
||||
|
||||
class LayerWithLayer(keras.layers.Layer):
|
||||
|
||||
def build(self, input_shape):
|
||||
@ -191,6 +204,7 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
('_float64_op', _float64_op),
|
||||
('_inner_layer', _inner_layer),
|
||||
('_reuse_ancillary_layer', _reuse_ancillary_layer),
|
||||
('_layer_with_tensor_arg', _layer_with_tensor_arg),
|
||||
)
|
||||
def test_autolambda(self, model_fn):
|
||||
model = model_fn()
|
||||
@ -208,7 +222,11 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
model(np_inputs) # Test calling the model directly on inputs.
|
||||
|
||||
new_model = keras.Model.from_config(
|
||||
model.get_config(), custom_objects={'LayerWithLayer': LayerWithLayer})
|
||||
model.get_config(),
|
||||
custom_objects={
|
||||
'LayerWithLayer': LayerWithLayer,
|
||||
'MyAdd': MyAdd
|
||||
})
|
||||
new_model.compile(
|
||||
adam.Adam(0.001),
|
||||
'mse',
|
||||
|
@ -53,6 +53,11 @@ NVCC_PATH = '/usr/local/cuda-10.0/bin/nvcc'
|
||||
PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
|
||||
NVCC_VERSION = '10.0'
|
||||
|
||||
# Name of the environment variable listing the supported TF CUDA Compute Capabilities,
|
||||
# e.g. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
|
||||
CUDA_COMPUTE_ENV_VAR = 'TF_CUDA_COMPUTE_CAPABILITIES'
|
||||
DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
|
||||
|
||||
def Log(s):
  """Writes `s` to standard output, tagged with the crosstool wrapper prefix."""
  tagged = 'gpus/crosstool: {0}'.format(s)
  print(tagged)
|
||||
|
||||
@ -202,7 +207,7 @@ def InvokeNvcc(argv, log=False):
|
||||
srcs = ' '.join(src_files)
|
||||
out = ' -o ' + out_file[0]
|
||||
|
||||
supported_cuda_compute_capabilities = [ "3.0", "6.0" ]
|
||||
supported_cuda_compute_capabilities = os.environ.get(CUDA_COMPUTE_ENV_VAR, DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
|
||||
nvccopts = '-D_FORCE_INLINES '
|
||||
for capability in supported_cuda_compute_capabilities:
|
||||
capability = capability.replace('.', '')
|
||||
|
@ -36,7 +36,14 @@ GCC_HOST_COMPILER_PATH = ('/dt7/usr/bin/gcc')
|
||||
NVCC_PATH = '/usr/local/cuda-10.0/bin/nvcc'
NVCC_VERSION = '10.0'
NVCC_TEMP_DIR = "C:\\Windows\\Temp\\nvcc_inter_files_tmp_dir"

# Compute capabilities used when TF_CUDA_COMPUTE_CAPABILITIES is not set.
DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'

# Supported TF CUDA Compute Capabilities, read from the environment as a
# comma-separated list,
# e.g. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
# Whitespace around entries is stripped and empty entries (such as those left
# by a trailing comma) are dropped, so every element is a clean "X.Y" string.
supported_cuda_compute_capabilities = [
    capability.strip()
    for capability in os.environ.get(
        'TF_CUDA_COMPUTE_CAPABILITIES',
        DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
    if capability.strip()
]
|
||||
|
||||
|
||||
def Log(s):
  """Writes `s` to standard output, tagged with the crosstool wrapper prefix."""
  tagged = 'gpus/crosstool: {0}'.format(s)
  print(tagged)
|
||||
|
Loading…
Reference in New Issue
Block a user