Support Tensorflow ops in Keras functional API for TPUs
TPU rewriting adds extra metadata attributes to graph nodes that are not captured by the TensorFlowOpLayer. In particular, ops and constants require the `_tpu_replicate` attribute, which is usually added in the rewrite step. Reconstructing the constants and post-processing the ops with the control flow context covers these cases.

PiperOrigin-RevId: 257226841
parent f72dc13764
commit eebe2f7247
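For context, the change targets TensorFlow ops used directly inside the Keras functional API: Keras wraps such ops in a TensorFlowOpLayer, and when the graph is rewritten for TPUs the ops and constants that layer creates need to carry the `_tpu_replicate` metadata. A minimal sketch of that usage (assuming a TF 1.14/2.0-era install; the op and shapes here are arbitrary):

import tensorflow as tf

inputs = tf.keras.Input(shape=(8,))
x = tf.keras.layers.Dense(16)(inputs)
# A raw TF op rather than a Keras layer; Keras wraps it in a
# TensorFlowOpLayer, whose _make_op path this commit changes.
x = tf.math.reduce_sum(x, axis=1, keepdims=True)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

# The wrapped op shows up as its own layer in the model.
print([type(layer).__name__ for layer in model.layers])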
tensorflow/python/keras
@@ -262,6 +262,8 @@ py_library(
         ":constraints",
         ":engine_utils",
         ":regularizers",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:constant_op",
         "//tensorflow/python/data",
         "//tensorflow/python/distribute:distribute_coordinator",
         "//tensorflow/python/distribute:distribute_lib",
@@ -37,6 +37,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import function
 from tensorflow.python.framework import auto_control_deps
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
@@ -2401,21 +2402,30 @@ class TensorFlowOpLayer(Layer):
       return self._defun_call(inputs)
     return self._make_op(inputs)
 
+  def _make_node_def(self, graph):
+    node_def = node_def_pb2.NodeDef()
+    node_def.CopyFrom(self.node_def)
+    node_def.name = graph.unique_name(node_def.name)
+    return node_def
+
   def _make_op(self, inputs):
     inputs = nest.flatten(inputs)
     graph = inputs[0].graph
+    node_def = self._make_node_def(graph)
     with graph.as_default():
       for index, constant in self.constants.items():
-        constant = ops.convert_to_tensor(constant)
+        # Recreate constant in graph to add distribution context.
+        value = tensor_util.constant_value(constant)
+        if value is not None:
+          constant = constant_op.constant(value, name=node_def.input[index])
         inputs.insert(index, constant)
-
-      self.node_def.name = graph.unique_name(self.node_def.name)
       # Check for case where first input should be a list of Tensors.
-      if 'N' in self.node_def.attr:
-        num_tensors = self.node_def.attr['N'].i
+      if 'N' in node_def.attr:
+        num_tensors = node_def.attr['N'].i
         inputs = [inputs[:num_tensors]] + inputs[num_tensors:]
-      c_op = ops._create_c_op(graph, self.node_def, inputs, control_inputs=[])
+      c_op = ops._create_c_op(graph, node_def, inputs, control_inputs=[])
       op = graph._create_op_from_tf_operation(c_op)
+      op._control_flow_post_processing()
 
       # Record the gradient because custom-made ops don't go through the
       # code-gen'd eager call path
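A note on the hunk above: _make_node_def copies the stored proto for each target graph instead of mutating self.node_def.name in place, so the same layer can be materialized into more than one graph (for example the original Keras graph and the TPU-rewritten one) without name clashes. A rough sketch of that pattern, with hypothetical names, using the same internal modules:

from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import ops

stored = node_def_pb2.NodeDef(name='my_op', op='Identity')  # what the layer holds

for graph in (ops.Graph(), ops.Graph()):
  per_graph = node_def_pb2.NodeDef()
  per_graph.CopyFrom(stored)
  # Unique within this graph only; the stored proto is left untouched.
  per_graph.name = graph.unique_name(per_graph.name)
  print(per_graph.name, stored.name)  # stored.name stays 'my_op'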
@@ -2426,8 +2436,7 @@ class TensorFlowOpLayer(Layer):
         attrs.append(attr_name)
         attrs.append(op.get_attr(attr_name))
       attrs = tuple(attrs)
-      execute.record_gradient(op_type, op.inputs, attrs, op.outputs,
-                              op.name)
+      execute.record_gradient(op_type, op.inputs, attrs, op.outputs, op.name)
 
       if len(op.outputs) == 1:
         return op.outputs[0]
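The other half of the fix is visible above: constants are rebuilt inside the target graph and the freshly created op goes through _control_flow_post_processing(), which is what lets an active graph context (such as the TPU replication context installed by the rewrite) see the node and attach its metadata. A small sketch of the constant-recreation step in isolation (hypothetical names, same internal modules as the code above):

import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util

captured = constant_op.constant(np.ones((2, 2)))  # captured when the layer was built
value = tensor_util.constant_value(captured)      # concrete numpy value, or None

target_graph = ops.Graph()
with target_graph.as_default():
  if value is not None:
    # Re-created inside the target graph, so any context active there
    # (e.g. a TPU replication context) can post-process it.
    rebuilt = constant_op.constant(value, name='rebuilt_constant')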