Support Tensorflow ops in Keras functional API for TPUs

TPU rewriting adds extra metadata attributes to graph nodes that are not captured
by the TensorFlowOpLayer. In particular, ops and constants require the
`_tpu_replicate` attribute, which is usually added in the rewrite step. Reconstructing
the constants and post-processing the ops with the control flow context covers
these cases.

PiperOrigin-RevId: 257226841
This commit is contained in:
Philip Pham 2019-07-09 10:48:13 -07:00 committed by TensorFlower Gardener
parent f72dc13764
commit eebe2f7247
2 changed files with 19 additions and 8 deletions
tensorflow/python/keras

View File

@ -262,6 +262,8 @@ py_library(
":constraints",
":engine_utils",
":regularizers",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:constant_op",
"//tensorflow/python/data",
"//tensorflow/python/distribute:distribute_coordinator",
"//tensorflow/python/distribute:distribute_lib",

View File

@ -37,6 +37,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import function
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
@ -2401,21 +2402,30 @@ class TensorFlowOpLayer(Layer):
return self._defun_call(inputs)
return self._make_op(inputs)
def _make_node_def(self, graph):
    """Return a copy of this layer's NodeDef with a name unique in `graph`.

    Args:
      graph: The `tf.Graph` the op will be created in; its namespace is used
        to de-duplicate the node name.

    Returns:
      A new `node_def_pb2.NodeDef` protobuf, identical to `self.node_def`
      except for its `name` field.
    """
    copied = node_def_pb2.NodeDef()
    copied.CopyFrom(self.node_def)
    # Rename so repeated instantiations of this layer don't collide in graph.
    copied.name = graph.unique_name(copied.name)
    return copied
def _make_op(self, inputs):
inputs = nest.flatten(inputs)
graph = inputs[0].graph
node_def = self._make_node_def(graph)
with graph.as_default():
for index, constant in self.constants.items():
constant = ops.convert_to_tensor(constant)
# Recreate constant in graph to add distribution context.
value = tensor_util.constant_value(constant)
if value is not None:
constant = constant_op.constant(value, name=node_def.input[index])
inputs.insert(index, constant)
self.node_def.name = graph.unique_name(self.node_def.name)
# Check for case where first input should be a list of Tensors.
if 'N' in self.node_def.attr:
num_tensors = self.node_def.attr['N'].i
if 'N' in node_def.attr:
num_tensors = node_def.attr['N'].i
inputs = [inputs[:num_tensors]] + inputs[num_tensors:]
c_op = ops._create_c_op(graph, self.node_def, inputs, control_inputs=[])
c_op = ops._create_c_op(graph, node_def, inputs, control_inputs=[])
op = graph._create_op_from_tf_operation(c_op)
op._control_flow_post_processing()
# Record the gradient because custom-made ops don't go through the
# code-gen'd eager call path
@ -2426,8 +2436,7 @@ class TensorFlowOpLayer(Layer):
attrs.append(attr_name)
attrs.append(op.get_attr(attr_name))
attrs = tuple(attrs)
execute.record_gradient(op_type, op.inputs, attrs, op.outputs,
op.name)
execute.record_gradient(op_type, op.inputs, attrs, op.outputs, op.name)
if len(op.outputs) == 1:
return op.outputs[0]