Merge pull request #31511 from tensorflow/ggadde-cp3
Cherry-picks to fix the disconnected graph issue and the missing CUDA compute capabilities.
commit 926e66c254
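For context, the disconnected graph issue arises when a Keras tensor is passed to a layer as an extra `call` argument: the Functional API did not track such tensors as inbound, so building the model could fail with a disconnected-graph error. A minimal repro sketch, mirroring the `MyAdd` test added in this PR:

from tensorflow import keras

class MyAdd(keras.layers.Layer):

  def call(self, x, y):
    return x + y

inputs = keras.Input(shape=(10,))
x = inputs * 2
# `x` reaches MyAdd only as an extra call argument; before this fix the
# graph traversal ignored it, so constructing the Model below could raise
# a disconnected-graph error.
outputs = MyAdd()(inputs, x)
model = keras.Model(inputs, outputs)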
@@ -225,7 +225,8 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
             # configured improperly.
             constants[i] = op_input
           else:
-            constants[i] = backend.function([], op_input)([])
+            with ops.init_scope():
+              constants[i] = backend.function([], op_input)([])
       processed_ops, created_layers = _create_keras_history_helper(
           layer_inputs, processed_ops, created_layers)
       name = op.name
@@ -239,7 +240,7 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
   return processed_ops, created_layers


-def needs_keras_history(tensors):
+def needs_keras_history(tensors, ignore_call_context=False):
   """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

   This will never return True inside a sublayer, because sublayers
@@ -249,12 +250,18 @@ def needs_keras_history(tensors):

   Arguments:
     tensors: An arbitrary nested structure of Tensors.
+    ignore_call_context: Whether to ignore the check of if currently
+      outside of a `call` context. This is `True` when creating
+      KerasHistory inside `Node`, where we always know that Tensors
+      are being used with the Functional API.

   Returns:
     Bool, whether at least one Tensor needs to be wrapped.
   """
   input_tensors = nest.flatten(tensors)
-  if call_context().in_call or all(
+  if call_context().in_call and not ignore_call_context:
+    return False
+  if all(
       getattr(tensor, '_keras_history', None) is not None
       for tensor in input_tensors):
     # KerasHistory already set.
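To make the new flag concrete, a short sketch of the intended call site (internal API; this is the `Node.__init__` pattern from later in this diff, with `arguments` standing for the layer-call kwargs):

from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.util import nest

# For each Tensor passed to a layer as an extra call argument, Node now
# builds Keras history eagerly; the in-call check is skipped because this
# is known to be Functional API usage:
for tensor_argument in [
    t for t in nest.flatten(arguments) if isinstance(t, ops.Tensor)]:
  if base_layer_utils.needs_keras_history(
      tensor_argument, ignore_call_context=True):
    base_layer_utils.create_keras_history(tensor_argument)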
@@ -1547,7 +1547,7 @@ class Network(base_layer.Layer):
     def _get_min_depth(node):
       """Gets the minimum depth at which node can be computed."""
       min_depth = 0
-      for layer, node_id, _, _ in node.iterate_inbound():
+      for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True):
         inbound_node = layer._inbound_nodes[node_id]
         if inbound_node in node_to_depth:
           min_depth = min(min_depth, node_to_depth[inbound_node])
@@ -1720,7 +1720,8 @@ def _map_graph_network(inputs, outputs):
    nodes_in_progress.add(node)

    # Propagate to all previous tensors connected to this node.
-    for layer, node_index, tensor_index, tensor in node.iterate_inbound():
+    for layer, node_index, tensor_index, tensor in node.iterate_inbound(
+        include_arguments=True):
      build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index,
                tensor_index)

@@ -20,6 +20,7 @@ from __future__ import print_function

 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.util import nest


@@ -111,6 +112,15 @@ class Node(object):
     # Optional keyword arguments to layer's `call`.
     self.arguments = arguments

+    # Create Keras History for any Keras Tensors in `arguments`.
+    tensor_arguments = [
+        t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor)
+    ]
+    for tensor_argument in tensor_arguments:
+      if base_layer_utils.needs_keras_history(
+          tensor_argument, ignore_call_context=True):
+        base_layer_utils.create_keras_history(tensor_argument)
+
     # Add nodes to all layers involved.
     for layer in nest.flatten(inbound_layers):
       if layer is not None:
@@ -121,15 +131,39 @@ class Node(object):
     # accessor here.
     outbound_layer.inbound_nodes.append(self)

-  def iterate_inbound(self):
+  def iterate_inbound(self, include_arguments=False):
     """Returns a list of tuples representing the inbound data.

+    Arguments:
+      include_arguments: Whether to also iterate over any Keras Tensors
+        passed as args, kwargs.
+
     Returns:
       List of tuples like: (inbound_layer, node_index, tensor_index, tensor).
     """
-    return zip(
-        nest.flatten(self.inbound_layers), nest.flatten(self.node_indices),
-        nest.flatten(self.tensor_indices), nest.flatten(self.input_tensors))
+    inputs_inbound = list(
+        zip(
+            nest.flatten(self.inbound_layers),
+            nest.flatten(self.node_indices),
+            nest.flatten(self.tensor_indices),
+            nest.flatten(self.input_tensors)))
+
+    if include_arguments:
+      keras_tensor_arguments = [
+          kt for kt in nest.flatten(self.arguments)
+          if hasattr(kt, '_keras_history')
+      ]
+
+      def _get_inbound(keras_tensor):
+        kh = keras_tensor._keras_history
+        return kh.layer, kh.node_index, kh.tensor_index, keras_tensor
+
+      arguments_inbound = nest.map_structure(_get_inbound,
+                                             keras_tensor_arguments)
+
+      return inputs_inbound + arguments_inbound
+    else:
+      return inputs_inbound

   def _get_all_node_dependencies(self):
     """Returns all of the nodes this node immediately depends on."""
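A hedged sketch of what the extended iterator yields, assuming the `MyAdd` model from the repro near the top of this page (`_inbound_nodes` is private API, used here purely for illustration):

node = model.layers[-1]._inbound_nodes[0]
# With include_arguments=True, the tensor passed as the extra call
# argument is reported alongside the regular inbound tensors, so depth
# computation and graph mapping no longer miss it.
for layer, node_index, tensor_index, tensor in node.iterate_inbound(
    include_arguments=True):
  print(layer.name, node_index, tensor_index)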
@@ -135,6 +135,19 @@ def _float64_op():
   return keras.Model(inputs, outputs)


+class MyAdd(keras.layers.Layer):
+
+  def call(self, x, y):
+    return x + y
+
+
+def _layer_with_tensor_arg():
+  inputs = keras.Input(shape=(10,))
+  x = inputs * 2
+  outputs = MyAdd()(inputs, x)
+  return keras.Model(inputs, outputs)
+
+
 class LayerWithLayer(keras.layers.Layer):

   def build(self, input_shape):
@@ -191,6 +204,7 @@ class AutoLambdaTest(keras_parameterized.TestCase):
       ('_float64_op', _float64_op),
       ('_inner_layer', _inner_layer),
       ('_reuse_ancillary_layer', _reuse_ancillary_layer),
+      ('_layer_with_tensor_arg', _layer_with_tensor_arg),
   )
   def test_autolambda(self, model_fn):
     model = model_fn()
@@ -208,7 +222,11 @@ class AutoLambdaTest(keras_parameterized.TestCase):
     model(np_inputs)  # Test calling the model directly on inputs.

     new_model = keras.Model.from_config(
-        model.get_config(), custom_objects={'LayerWithLayer': LayerWithLayer})
+        model.get_config(),
+        custom_objects={
+            'LayerWithLayer': LayerWithLayer,
+            'MyAdd': MyAdd
+        })
     new_model.compile(
         adam.Adam(0.001),
         'mse',
@@ -53,6 +53,11 @@ NVCC_PATH = '/usr/local/cuda-10.0/bin/nvcc'
 PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
 NVCC_VERSION = '10.0'

+# Environment variable for supported TF CUDA Compute Capabilities
+# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
+CUDA_COMPUTE_ENV_VAR = 'TF_CUDA_COMPUTE_CAPABILITIES'
+DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
+
 def Log(s):
   print('gpus/crosstool: {0}'.format(s))

@@ -202,7 +207,7 @@ def InvokeNvcc(argv, log=False):
   srcs = ' '.join(src_files)
   out = ' -o ' + out_file[0]

-  supported_cuda_compute_capabilities = [ "3.0", "6.0" ]
+  supported_cuda_compute_capabilities = os.environ.get(CUDA_COMPUTE_ENV_VAR, DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
   nvccopts = '-D_FORCE_INLINES '
   for capability in supported_cuda_compute_capabilities:
     capability = capability.replace('.', '')
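A small sketch of the effect of this change, assuming the surrounding wrapper turns each capability into a -gencode flag for nvcc (the printed flags are illustrative; the exact flag string comes from the wrapper code around this hunk):

import os

# e.g. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,6.0,7.0 before building;
# falls back to the baked-in default when the variable is unset.
caps = os.environ.get('TF_CUDA_COMPUTE_CAPABILITIES', '3.5,6.0').split(',')
for capability in caps:
  capability = capability.replace('.', '')  # '7.0' -> '70'
  print('-gencode=arch=compute_%s,code=sm_%s' % (capability, capability))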
@@ -36,7 +36,14 @@ GCC_HOST_COMPILER_PATH = ('/dt7/usr/bin/gcc')
 NVCC_PATH = '/usr/local/cuda-10.0/bin/nvcc'
 NVCC_VERSION = '10.0'
 NVCC_TEMP_DIR = "C:\\Windows\\Temp\\nvcc_inter_files_tmp_dir"
-supported_cuda_compute_capabilities = [ "3.0", "6.0" ]
+DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
+
+# Taken from environment variable for supported TF CUDA Compute Capabilities
+# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
+supported_cuda_compute_capabilities = os.environ.get(
+    'TF_CUDA_COMPUTE_CAPABILITIES',
+    DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')

 def Log(s):
   print('gpus/crosstool: {0}'.format(s))
