From eebe2f7247ff8ece3864bb100f3030f0168d1120 Mon Sep 17 00:00:00 2001
From: Philip Pham
Date: Tue, 9 Jul 2019 10:48:13 -0700
Subject: [PATCH] Support TensorFlow ops in Keras functional API for TPUs

TPU rewriting adds extra metadata attributes to graph nodes that are not
captured by `TensorFlowOpLayer`. In particular, ops and constants require the
`_tpu_replicate` attribute, which is usually added in the rewrite step.
Reconstructing the constants and post-processing the ops with the control
flow context covers these cases.

PiperOrigin-RevId: 257226841
---
 tensorflow/python/keras/BUILD                |  2 ++
 tensorflow/python/keras/engine/base_layer.py | 25 +++++++++++++-------
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 3da6e3f4efa..fb4bb82d197 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -262,6 +262,8 @@ py_library(
         ":constraints",
         ":engine_utils",
         ":regularizers",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:constant_op",
         "//tensorflow/python/data",
         "//tensorflow/python/distribute:distribute_coordinator",
         "//tensorflow/python/distribute:distribute_lib",
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index c5bc2d2c219..a7b1721335c 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -37,6 +37,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import function
 from tensorflow.python.framework import auto_control_deps
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
@@ -2401,21 +2402,30 @@ class TensorFlowOpLayer(Layer):
       return self._defun_call(inputs)
     return self._make_op(inputs)
 
+  def _make_node_def(self, graph):
+    node_def = node_def_pb2.NodeDef()
+    node_def.CopyFrom(self.node_def)
+    node_def.name = graph.unique_name(node_def.name)
+    return node_def
+
   def _make_op(self, inputs):
     inputs = nest.flatten(inputs)
     graph = inputs[0].graph
+    node_def = self._make_node_def(graph)
     with graph.as_default():
       for index, constant in self.constants.items():
-        constant = ops.convert_to_tensor(constant)
+        # Recreate constant in graph to add distribution context.
+        value = tensor_util.constant_value(constant)
+        if value is not None:
+          constant = constant_op.constant(value, name=node_def.input[index])
         inputs.insert(index, constant)
-
-      self.node_def.name = graph.unique_name(self.node_def.name)
       # Check for case where first input should be a list of Tensors.
-      if 'N' in self.node_def.attr:
-        num_tensors = self.node_def.attr['N'].i
+      if 'N' in node_def.attr:
+        num_tensors = node_def.attr['N'].i
         inputs = [inputs[:num_tensors]] + inputs[num_tensors:]
-      c_op = ops._create_c_op(graph, self.node_def, inputs, control_inputs=[])
+      c_op = ops._create_c_op(graph, node_def, inputs, control_inputs=[])
       op = graph._create_op_from_tf_operation(c_op)
+      op._control_flow_post_processing()
 
       # Record the gradient because custom-made ops don't go through the
       # code-gen'd eager call path
@@ -2426,8 +2436,7 @@ class TensorFlowOpLayer(Layer):
         attrs.append(attr_name)
         attrs.append(op.get_attr(attr_name))
       attrs = tuple(attrs)
-      execute.record_gradient(op_type, op.inputs, attrs, op.outputs,
-                              op.name)
+      execute.record_gradient(op_type, op.inputs, attrs, op.outputs, op.name)
 
       if len(op.outputs) == 1:
         return op.outputs[0]
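
Note (illustrative, not part of the patch): the code path changed above runs
whenever raw TensorFlow ops are applied to Keras symbolic tensors. The
functional API wraps each such op in a `TensorFlowOpLayer`, and Python scalars
used in those expressions become the captured constants that `_make_op` now
recreates inside the TPU graph so the rewrite can attach `_tpu_replicate` to
them. Below is a minimal sketch of the kind of model this enables, assuming a
TF 1.14-era runtime; the TPU address and the random training data are
placeholders:

  import numpy as np
  import tensorflow as tf

  def build_model():
    inputs = tf.keras.Input(shape=(10,))
    x = tf.keras.layers.Dense(16, activation='relu')(inputs)
    # Raw TF ops on symbolic tensors; Keras wraps them in TensorFlowOpLayer.
    # The scalars 1.0 and 2.0 are captured as constants, the case that
    # previously failed under TPU rewriting for lack of `_tpu_replicate`.
    x = tf.sqrt(x + 1.0) * 2.0
    outputs = tf.keras.layers.Dense(1)(x)
    return tf.keras.Model(inputs, outputs)

  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu='grpc://10.0.0.1:8470')  # placeholder address
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.experimental.TPUStrategy(resolver)

  with strategy.scope():
    model = build_model()
    model.compile(optimizer='sgd', loss='mse')

  model.fit(np.random.rand(64, 10).astype('float32'),
            np.random.rand(64, 1).astype('float32'),
            batch_size=8, epochs=1)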