diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index 56e83f03cce..ee0bb4a55a8 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -430,7 +430,7 @@ class FuncGraph(ops.Graph): compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list, context.context()) else: - op = ops.get_default_graph().create_op( + op = ops.get_default_graph()._create_op_internal( # pylint: disable=protected-access op_type, uncaptured_inputs, dtypes, @@ -438,7 +438,7 @@ class FuncGraph(ops.Graph): name, attrs, op_def, - compute_device=compute_device) + compute_device) value = op.outputs[0] captured_value = self.capture(value) return captured_value.op @@ -510,9 +510,9 @@ class FuncGraph(ops.Graph): inp = ctxt.AddValue(inp) inp = self.capture(inp) inputs[i] = inp - return super(FuncGraph, self).create_op( + return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access op_type, inputs, dtypes, input_types, name, attrs, op_def, - compute_device=compute_device) + compute_device) def capture(self, tensor, name=None): """Captures `tensor` if it's external to this graph. diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 700431375f9..014167f6d06 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3269,11 +3269,55 @@ class Graph(object): An `Operation` object. """ del compute_shapes - - self._check_not_finalized() for idx, a in enumerate(inputs): if not isinstance(a, Tensor): raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) + return self._create_op_internal(op_type, inputs, dtypes, input_types, name, + attrs, op_def, compute_device) + + def _create_op_internal( + self, + op_type, + inputs, + dtypes=None, # pylint: disable=redefined-outer-name + input_types=None, + name=None, + attrs=None, + op_def=None, + compute_device=True): + """Creates an `Operation` in this graph. + + Implements `Graph.create_op()` without the overhead of the deprecation + wrapper. + + Args: + op_type: The `Operation` type to create. This corresponds to the + `OpDef.name` field for the proto that defines the operation. + inputs: A list of `Tensor` objects that will be inputs to the `Operation`. + dtypes: (Optional) A list of `DType` objects that will be the types of the + tensors that the operation produces. + input_types: (Optional.) A list of `DType`s that will be the types of the + tensors that the operation consumes. By default, uses the base `DType` + of each input in `inputs`. Operations that expect reference-typed inputs + must specify `input_types` explicitly. + name: (Optional.) A string name for the operation. If not specified, a + name is generated based on `op_type`. + attrs: (Optional.) A dictionary where the key is the attribute name (a + string) and the value is the respective `attr` attribute of the + `NodeDef` proto that will represent the operation (an `AttrValue` + proto). + op_def: (Optional.) The `OpDef` proto that describes the `op_type` that + the operation will have. + compute_device: (Optional.) If True, device functions will be executed to + compute the device property of the Operation. + + Raises: + ValueError: if colocation conflicts with existing device assignment. + + Returns: + An `Operation` object. + """ + self._check_not_finalized() if name is None: name = op_type # If a names ends with a '/' it is a "name scope" and we use it as-is,