Removed internal calls to @RegisterShape and related APIs

All of that code is effectively dead since shape inference happens in C++.

PiperOrigin-RevId: 269904930
Sergei Lebedev 2019-09-18 15:31:08 -07:00 committed by TensorFlower Gardener
parent 3ef0d0d074
commit 205bf5260c
10 changed files with 14 additions and 717 deletions
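Background: output shapes of registered ops are computed by the C++ shape functions attached to each op registration, so graph construction never consults the Python-side registry this commit deletes. A minimal sketch using only public TF 1.x API (nothing below depends on the removed code):

    import tensorflow.compat.v1 as tf

    a = tf.placeholder(tf.float32, shape=[None, 3])
    b = tf.placeholder(tf.float32, shape=[3, 7])
    c = tf.matmul(a, b)
    print(c.shape)  # (?, 7) -- inferred by the C++ shape function for MatMul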


@@ -77,7 +77,6 @@ tf_custom_op_py_library(
         ":audio_microfrontend_op",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
-        "//tensorflow/python:common_shapes",
         ":constant_op",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_for_generated_wrappers",


@@ -776,7 +776,6 @@ py_library(
         ":_pywrap_tfprof",
         ":_pywrap_util_port",
         ":_pywrap_utils",
-        ":common_shapes",
         ":composite_tensor",
         ":convert_to_constants",
         ":cpp_shape_inference_proto_py",
@@ -815,13 +814,7 @@ py_library(
     srcs = ["framework/common_shapes.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":cpp_shape_inference_proto_py",
-        ":errors",
-        ":framework_ops",
-        ":pywrap_tensorflow",
         ":tensor_shape",
-        ":tensor_util",
-        "//tensorflow/core:protos_all_py",
     ],
 )


@@ -26,7 +26,6 @@ import numpy as np
 from tensorflow.core.protobuf import cluster_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
-from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -43,11 +42,6 @@ from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
 
-# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
-# don't have C++ op registrations on which to attach C++ shape fns.
-ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
-
 
 class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
 
   def testClusterSpecPropagationSimple(self):


@@ -22,11 +22,9 @@ from __future__ import print_function
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.client import session
-from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -34,11 +32,6 @@ from tensorflow.python.platform import googletest
 from tensorflow.python.training import server_lib
 
-# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
-# don't have C++ op registrations on which to attach C++ shape fns.
-ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
-
 
 class PartialRunTest(test_util.TensorFlowTestCase):
 
   def RunTestPartialRun(self, sess):


@@ -34,7 +34,6 @@ from tensorflow.core.lib.core import error_codes_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
-from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as framework_device_lib
@@ -69,11 +68,6 @@ except ImportError:
   attr = None
 
-# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
-# don't have C++ op registrations on which to attach C++ shape fns.
-ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
-
 
 class SessionTest(test_util.TensorFlowTestCase):
 
   def setUp(self):


@@ -17,498 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-import six.moves
+import six
 
-from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.framework import cpp_shape_inference_pb2
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-
-
-def has_fully_defined_shape(tensor):
-  """Returns true if tensor has a fully defined shape."""
-  return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined()
-
-
-def rank(tensor):
-  """Return a rank if it is a tensor, else return None."""
-  if isinstance(tensor, ops.Tensor):
-    return tensor._rank()  # pylint: disable=protected-access
-  return None
-
-
-def scalar_shape(unused_op):
-  """Shape function for ops that output a scalar value."""
-  return [tensor_shape.TensorShape([])]
-
-
-def unchanged_shape(op):
-  """Shape function for ops that output a tensor like their first input."""
-  return [op.inputs[0].get_shape()]
-
-
-def unchanged_shape_with_rank(rank):
-  """Returns a shape function for ops that constrain the rank of their input.
-
-  Args:
-    rank: The exact rank of the input and output.
-
-  Returns:
-    A shape function for ops that output a tensor of the same size as their
-    input, with a particular rank.
-  """
-
-  def _ShapeFunction(op):
-    return [op.inputs[0].get_shape().with_rank(rank)]
-
-  return _ShapeFunction
-
-
-def unchanged_shape_with_rank_at_least(rank):
-  """Returns a shape function for ops that constrain the rank of their input.
-
-  Args:
-    rank: A lower bound on the rank of the input and output.
-
-  Returns:
-    A shape function for ops that output a tensor of the same size as their
-    input, with a particular rank.
-  """
-
-  def _ShapeFunction(op):
-    return [op.inputs[0].get_shape().with_rank_at_least(rank)]
-
-  return _ShapeFunction
-
-
-def unchanged_shape_with_rank_at_most(rank):
-  """Returns a shape function for ops that constrain the rank of their input.
-
-  Args:
-    rank: An upper bound on the rank of the input and output.
-
-  Returns:
-    A shape function for ops that output a tensor of the same size as their
-    input, with a particular rank.
-  """
-
-  def _ShapeFunction(op):
-    return [op.inputs[0].get_shape().with_rank_at_most(rank)]
-
-  return _ShapeFunction
-
-
-def matmul_shape(op):
-  """Shape function for a MatMul op."""
-  a_shape = op.inputs[0].get_shape().with_rank(2)
-  transpose_a = op.get_attr("transpose_a")
-  b_shape = op.inputs[1].get_shape().with_rank(2)
-  transpose_b = op.get_attr("transpose_b")
-  output_rows = a_shape[1] if transpose_a else a_shape[0]
-  output_cols = b_shape[0] if transpose_b else b_shape[1]
-  inner_a = a_shape[0] if transpose_a else a_shape[1]
-  inner_b = b_shape[1] if transpose_b else b_shape[0]
-  inner_a.assert_is_compatible_with(inner_b)
-  return [tensor_shape.TensorShape([output_rows, output_cols])]
-
-
-def get_conv_output_size(input_size, filter_size, strides, padding_type):
-  """Returns the spatial size of a n-d convolution/pooling output."""
-  input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
-  filter_size = tuple([tensor_shape.as_dimension(x).value for x in filter_size])
-  strides = [int(x) for x in strides]
-
-  if all(x == 1 for x in input_size) and all(x == 1 for x in filter_size):
-    return input_size
-
-  if any(x is not None and y is not None and x > y for x, y in
-         zip(filter_size, input_size)):
-    raise ValueError("Filter must not be larger than the input: "
-                     "Filter: %r Input: %r" % (filter_size, input_size))
-
-  if padding_type == b"VALID":
-
-    def _valid(in_dim, k_dim, s_dim):
-      if in_dim is not None and k_dim is not None:
-        return (in_dim - k_dim + s_dim) // s_dim
-      else:
-        return None
-
-    output_size = [
-        _valid(in_dim, k_dim, s_dim)
-        for in_dim, k_dim, s_dim in zip(input_size, filter_size, strides)
-    ]
-  elif padding_type == b"SAME":
-
-    def _same(in_dim, s_dim):
-      if in_dim is not None:
-        return (in_dim + s_dim - 1) // s_dim
-      else:
-        return None
-
-    output_size = [_same(in_dim, s_dim)
-                   for in_dim, s_dim in zip(input_size, strides)]
-  else:
-    raise ValueError("Invalid padding: %r" % padding_type)
-
-  return tuple(output_size)
-
-
-def get2d_conv_output_size(input_height, input_width, filter_height,
-                           filter_width, row_stride, col_stride, padding_type):
-  """Returns the number of rows and columns in a convolution/pooling output."""
-  return get_conv_output_size((input_height, input_width),
-                              (filter_height, filter_width),
-                              (row_stride, col_stride), padding_type)
-
-
-def conv2d_shape(op):
-  """Shape function for a Conv2D op.
-
-  This op has two inputs:
-
-  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
-  * filter, a 4D tensor with shape = [filter_rows, filter_cols,
-    depth_in, depth_out]
-
-  The output is a 4D tensor with shape = [batch_size, out_rows,
-  out_cols, depth_out], where out_rows and out_cols depend on the
-  value of the op's "padding" and "strides" attrs.
-
-  Args:
-    op: A Conv2D Operation.
-
-  Returns:
-    A list containing the Shape of the Conv2D output.
-
-  Raises:
-    ValueError: If the shapes of the input or filter are incompatible.
-  """
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-  filter_shape = op.inputs[1].get_shape().with_rank(4)
-
-  try:
-    data_format = op.get_attr("data_format")
-  except ValueError:
-    data_format = None
-
-  if data_format == b"NCHW":
-    # Convert input shape to the default NHWC for inference.
-    input_shape = [input_shape[0], input_shape[2], input_shape[3],
-                   input_shape[1]]
-
-  batch_size = input_shape[0]
-  in_rows = input_shape[1]
-  in_cols = input_shape[2]
-
-  filter_rows = filter_shape[0]
-  filter_cols = filter_shape[1]
-  depth_out = filter_shape[3]
-  # Check that the input depths are compatible.
-  input_shape[3].assert_is_compatible_with(filter_shape[2])
-
-  if data_format == b"NCHW":
-    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
-  else:
-    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
-
-  if stride_b != 1 or stride_d != 1:
-    raise ValueError("Current implementation does not yet support "
-                     "strides in the batch and depth dimensions.")
-  # TODO(mrry,shlens): Raise an error if the stride would cause
-  # information in the input to be ignored. This will require a change
-  # in the kernel implementation.
-  padding = op.get_attr("padding")
-  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
-                                              filter_cols, stride_r, stride_c,
-                                              padding)
-
-  output_shape = [batch_size, out_rows, out_cols, depth_out]
-  if data_format == b"NCHW":
-    # Convert output shape back to NCHW.
-    output_shape = [output_shape[0], output_shape[3], output_shape[1],
-                    output_shape[2]]
-  return [tensor_shape.TensorShape(output_shape)]
-
-
-def depthwise_conv2d_native_shape(op):
-  """Shape function for a DepthwiseConv2D op.
-
-  This op has two inputs:
-
-  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
-  * filter, a 4D tensor with shape = [filter_rows, filter_cols,
-    depth_in, depthwise_multiplier]
-
-  The output is a 4D tensor with shape = [batch_size, out_rows,
-  out_cols, depth_in*depthwise_multiplier], where out_rows and out_cols depend
-  on the value of the op's "padding" and "strides" attrs.
-
-  Args:
-    op: A DepthwiseConv2dNative Operation.
-
-  Returns:
-    A list containing the Shape of the DepthwiseConv2DNative output.
-
-  Raises:
-    ValueError: If the shapes of the input or filter are incompatible.
-  """
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-  filter_shape = op.inputs[1].get_shape().with_rank(4)
-
-  batch_size = input_shape[0]
-  in_rows = input_shape[1]
-  in_cols = input_shape[2]
-
-  filter_rows = filter_shape[0]
-  filter_cols = filter_shape[1]
-  depth_out = filter_shape[3] * filter_shape[2]
-  # Check that the input depths are compatible.
-  input_shape[3].assert_is_compatible_with(filter_shape[2])
-
-  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
-  if stride_b != 1 or stride_d != 1:
-    raise ValueError("Current implementation does not yet support "
-                     "strides in the batch and depth dimensions.")
-  if stride_r != stride_c:
-    # TODO(shlens): Add support for this.
-    raise ValueError("Current implementation only supports equal length "
-                     "strides in the row and column dimensions.")
-
-  # TODO(mrry,shlens): Raise an error if the stride would cause
-  # information in the input to be ignored. This will require a change
-  # in the kernel implementation.
-  stride = stride_r
-  padding = op.get_attr("padding")
-  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
-                                              filter_cols, stride, stride,
-                                              padding)
-
-  return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
-
-
-def separable_conv2d_shape(op):
-  """Shape function for a SeparableConv2D op.
-
-  This op has three inputs:
-
-  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
-  * depthwise_filter, a 4D tensor with shape = [filter_rows,
-    filter_cols, depth_in, depth_multiplier]
-  * pointwise_filter, a 4D tensor with shape = [1, 1, depth_in *
-    depth_multiplier, depth_out]
-
-  The output is a 4D tensor with shape = [batch_size, out_rows,
-  out_cols, depth_out], where out_rows and out_cols depend on the
-  value of the op's "padding" and "strides" attrs.
-
-  Args:
-    op: A SeparableConv2D Operation.
-
-  Returns:
-    A list containing the Shape of the SeparableConv2D output.
-
-  Raises:
-    ValueError: If the shapes of the input or filter are incompatible.
-  """
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-
-  depthwise_filter_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.TensorShape([None, None, input_shape[3], None]))
-  pointwise_depth_in = depthwise_filter_shape[2] * depthwise_filter_shape[3]
-
-  pointwise_filter_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.TensorShape([1, 1, pointwise_depth_in, None]))
-
-  batch_size = input_shape[0]
-  in_rows = input_shape[1]
-  in_cols = input_shape[2]
-
-  filter_rows = depthwise_filter_shape[0]
-  filter_cols = depthwise_filter_shape[1]
-  depth_out = pointwise_filter_shape[3]
-
-  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
-  if stride_b != 1 or stride_d != 1:
-    raise ValueError("Current implementation does not yet support "
-                     "strides in the batch and depth dimensions.")
-  if stride_r != stride_c:
-    # TODO(shlens): Add support for this.
-    raise ValueError("Current implementation only supports equal length "
-                     "strides in the row and column dimensions.")
-
-  # TODO(mrry,shlens): Raise an error if the stride would cause
-  # information in the input to be ignored. This will require a change
-  # in the kernel implementation.
-  stride = stride_r
-  padding = op.get_attr("padding")
-  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
-                                              filter_cols, stride, stride,
-                                              padding)
-
-  return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
-
-
-def avg_pool_shape(op):
-  """Shape function for an AvgPool op.
-
-  This op has one input:
-
-  * input, a 4D tensor with shape = [batch_size, rows, cols, depth]
-
-  The output is a 4D tensor with shape = [batch_size, out_rows,
-  out_cols, depth_out], where out_rows and out_cols depend on the
-  value of the op's "ksize", "strides", and "padding" attrs.
-
-  Args:
-    op: An AvgPool Operation.
-
-  Returns:
-    A single-element list containing the Shape of the AvgPool output.
-
-  Raises:
-    ValueError: If the shape of the input is invalid or incompatible with
-      the values of the attrs.
-  """
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-
-  try:
-    data_format = op.get_attr("data_format")
-  except ValueError:
-    data_format = None
-
-  if data_format == b"NCHW":
-    # Convert input shape to the default NHWC for inference.
-    input_shape = [input_shape[0], input_shape[2], input_shape[3],
-                   input_shape[1]]
-
-  if data_format == b"NCHW":
-    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
-    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
-  else:
-    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
-    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
-
-  batch_size = input_shape[0]
-  in_rows = input_shape[1]
-  in_cols = input_shape[2]
-  depth = input_shape[3]
-
-  if ksize_b != 1 or ksize_d != 1:
-    raise ValueError("Current implementation does not support pooling "
-                     "in the batch and depth dimensions.")
-  if stride_b != 1 or stride_d != 1:
-    raise ValueError("Current implementation does not support strides "
-                     "in the batch and depth dimensions.")
-
-  # TODO(mrry,shlens): Raise an error if the stride would cause
-  # information in the input to be ignored. This will require a change
-  # in the kernel implementation.
-  padding = op.get_attr("padding")
-
-  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
-                                              ksize_c, stride_r, stride_c,
-                                              padding)
-
-  output_shape = [batch_size, out_rows, out_cols, depth]
-  if data_format == b"NCHW":
-    # Convert output shape back to NCHW.
-    output_shape = [output_shape[0], output_shape[3], output_shape[1],
-                    output_shape[2]]
-  return [tensor_shape.TensorShape(output_shape)]
-
-
-def max_pool_shape(op):
-  """Shape function for a MaxPool op.
-
-  This op has one input:
-
-  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
-
-  The output is a 4D tensor with shape = [batch_size, out_rows,
-  out_cols, depth_out], where out_rows, out_cols, and depth_out depend
-  on the value of the op's "ksize", "strides", and "padding" attrs.
-
-  Args:
-    op: A MaxPool Operation.
-
-  Returns:
-    A single-element list containing the Shape of the MaxPool output.
-
-  Raises:
-    ValueError: If the shape of the input is invalid or incompatible with
-      the values of the attrs.
-  """
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-
-  try:
-    data_format = op.get_attr("data_format")
-  except ValueError:
-    data_format = None
-
-  if data_format == b"NCHW":
-    # Convert input shape to the default NHWC for inference.
-    input_shape = [input_shape[0], input_shape[2], input_shape[3],
-                   input_shape[1]]
-
-  if data_format == b"NCHW":
-    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
-    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
-  else:
-    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
-    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
-
-  batch_size = input_shape[0]
-  in_rows = input_shape[1]
-  in_cols = input_shape[2]
-  depth = input_shape[3]
-
-  if ksize_b != 1:
-    raise ValueError("Current implementation does not support pooling "
-                     "in the batch dimension.")
-  if stride_b != 1:
-    raise ValueError("Current implementation does not support strides "
-                     "in the batch dimension.")
-
-  if not ((ksize_r == 1 and ksize_c == 1) or ksize_d == 1):
-    raise ValueError("MaxPooling supports exactly one of pooling across depth "
-                     "or pooling across width/height.")
-
-  # TODO(mrry,shlens): Raise an error if the stride would cause
-  # information in the input to be ignored. This will require a change
-  # in the kernel implementation.
-  if ksize_d == 1:
-    padding = op.get_attr("padding")
-    out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
-                                                ksize_c, stride_r, stride_c,
-                                                padding)
-    output_shape = [batch_size, out_rows, out_cols, depth]
-  else:
-    if depth % ksize_d > 0:
-      raise ValueError("Depthwise max pooling requires the depth window "
-                       "to evenly divide the input depth.")
-    if stride_d != ksize_d:
-      raise ValueError("Depthwise max pooling requires the depth window "
-                       "to equal the depth stride.")
-    output_shape = [batch_size, in_rows, in_cols, depth // ksize_d]
-
-  if data_format == b"NCHW":
-    # Convert output shape back to NCHW.
-    output_shape = [output_shape[0], output_shape[3], output_shape[1],
-                    output_shape[2]]
-  return [tensor_shape.TensorShape(output_shape)]
-
-
-def no_outputs(unused_op):
-  """Shape function for use with ops that have no outputs."""
-  return []
-
-
-def unknown_shape(op):
-  """Shape function for use with ops whose output shapes are unknown."""
-  return [tensor_shape.unknown_shape() for _ in op.outputs]
 def _broadcast_shape_helper(shape_x, shape_y):
@@ -595,136 +106,3 @@ def broadcast_shape(shape_x, shape_y):
     raise ValueError("Incompatible shapes for broadcasting: %s and %s"
                      % (shape_x, shape_y))
   return tensor_shape.TensorShape(return_dims)
-
-
-def call_cpp_shape_fn(op, require_shape_fn=True):
-  """A shape function that delegates to the registered C++ shape function.
-
-  Args:
-    op: the node in the graph for which to compute output shapes.
-    require_shape_fn: If true, and the C++ shape function is not registered
-      in the current binary then an exception is raised; otherwise, if the
-      C++ shape function is not registered then unknown_shape is used.
-
-  Returns:
-    A dictionary with the following keys:
-      shapes: A TensorShape list of the output shapes of the op, as computed
-        using the C++ shape inference function registered for the op.
-      handle_shapes: A TensorShape list of the shapes for handle outputs, if
-        any.
-      handle_dtypes: A list of DataType enums for the handle outputs, if any.
-
-  Raises:
-    ValueError: If the C++ shape function returned an error (e.g. because the
-      shapes of the inputs are of the wrong rank or otherwise incompatible
-      according to the shape function).
-    RuntimeError: If the C++ shape function is not registered and
-      <require_shape_fn> is True.
-  """
-  if op.type == "Const":
-    # To avoid serializing large constants, we special-case constant
-    # here, even though it has a C++ shape function. When Python
-    # calls the C / C-API directly, we should be able to remove this.
-    return {
-        "shapes": [tensor_shape.TensorShape(op.get_attr("value").tensor_shape)],
-        "handle_data": [None]
-    }
-
-  input_tensors_needed = []
-  input_tensors_as_shapes_needed = []
-
-  while True:
-    res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
-                                  input_tensors_as_shapes_needed,
-                                  require_shape_fn)
-    if not isinstance(res, dict):
-      # Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).
-      return res
-
-    # See if we need to evaluate some inputs.
-    if not res["inputs_needed"]:
-      return res
-    p = cpp_shape_inference_pb2.CppShapeInferenceInputsNeeded()
-    p = p.FromString(res["inputs_needed"])
-    changed = False
-    for idx in p.input_tensors_needed:
-      if idx not in input_tensors_needed:
-        input_tensors_needed.append(idx)
-        changed = True
-    for idx in p.input_tensors_as_shapes_needed:
-      if idx not in input_tensors_as_shapes_needed:
-        input_tensors_as_shapes_needed.append(idx)
-        changed = True
-    if not changed:
-      return res
-
-
-def _call_cpp_shape_fn_impl(
-    op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn):
-  """Core implementation of call_cpp_shape_fn."""
-  graph_def_version = op.graph.graph_def_versions.producer
-  node_def_str = op.node_def.SerializeToString()
-
-  def tensor_to_inference_result(t):
-    r = cpp_shape_inference_pb2.CppShapeInferenceResult()
-    r.shape.CopyFrom(t.get_shape().as_proto())
-    # pylint: disable=protected-access
-    if t._handle_data is not None:
-      r.handle_data.CopyFrom(t._handle_data)
-    # pylint: enable=protected-access
-    return r.SerializeToString()
-
-  input_shapes = [tensor_to_inference_result(i) for i in op.inputs]
-
-  input_tensors = [None for i in input_shapes]
-  for idx in input_tensors_needed:
-    v = tensor_util.constant_value(op.inputs[idx])
-    if v is not None:
-      input_tensors[idx] = np.asarray(v)
-
-  serialized_unknown_shape = (
-      tensor_shape.TensorShape(None).as_proto().SerializeToString())
-  arr = [serialized_unknown_shape for i in input_shapes]
-  for idx in input_tensors_as_shapes_needed:
-    s = tensor_util.constant_value_as_shape(op.inputs[idx])
-    if s is not None:
-      arr[idx] = s.as_proto().SerializeToString()
-  input_tensors_as_shapes = arr
-
-  missing_shape_fn = False
-  try:
-    output = pywrap_tensorflow.RunCppShapeInference(
-        graph_def_version, node_def_str, input_shapes, input_tensors,
-        input_tensors_as_shapes)
-  except errors.InvalidArgumentError as err:
-    if err.message.startswith("No shape inference function exists for op"):
-      missing_shape_fn = True
-    else:
-      raise ValueError(err.message)
-
-  if missing_shape_fn:
-    if require_shape_fn:
-      raise RuntimeError(
-          "No C++ shape function registered for standard op: %s" % op.type)
-    return unknown_shape(op)
-
-  output_shapes = output[:-1]
-
-  # Convert TensorShapeProto values in output_shapes.
-  result_protos = [
-      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s)
-      for s in output_shapes
-  ]
-  result = [r.shape for r in result_protos]
-  result_handle_data = [
-      r.handle_data if r.handle_data.is_set else None for r in result_protos
-  ]
-
-  return {
-      "shapes": result,
-      "handle_data": result_handle_data,
-      "inputs_needed": output[-1]
-  }
-
-
-# pylint: disable=protected-access
-ops._set_call_cpp_shape_fn(call_cpp_shape_fn)
-# pylint: enable=protected-access


@@ -59,7 +59,6 @@ from tensorflow.python.framework.tensor_util import MakeNdarray as make_ndarray
 from tensorflow.python.framework.ops import RegisterGradient
 from tensorflow.python.framework.ops import NotDifferentiable
 from tensorflow.python.framework.ops import NoGradient
-from tensorflow.python.framework.ops import RegisterShape
 from tensorflow.python.framework.tensor_shape import Dimension
 from tensorflow.python.framework.tensor_shape import TensorShape


@@ -2534,33 +2534,6 @@ def get_gradient_function(op):
   return _gradient_registry.lookup(op_type)
 
 
-_shape_registry = registry.Registry("shape functions")
-_default_shape_function_registry = registry.Registry("default shape functions")
-
-# These are set to common_shapes.call_cpp_shape_fn by op generated code
-# (generated by python_op_gen.cc).
-# It is set outside ops.py to avoid a circular dependency.
-_call_cpp_shape_fn = None
-_call_cpp_shape_fn_and_require_op = None
-
-
-def _set_call_cpp_shape_fn(call_cpp_shape_fn):
-  """Sets default shape fns from passed common_shapes.call_cpp_shape_fn."""
-  global _call_cpp_shape_fn, _call_cpp_shape_fn_and_require_op
-  if _call_cpp_shape_fn:
-    return  # already registered
-
-  def call_without_requiring(op):
-    return call_cpp_shape_fn(op, require_shape_fn=False)
-
-  _call_cpp_shape_fn = call_without_requiring
-
-  def call_with_requiring(op):
-    return call_cpp_shape_fn(op, require_shape_fn=True)
-
-  _call_cpp_shape_fn_and_require_op = call_with_requiring
-
 
 class RegisterShape(object):
   """No longer used.
@@ -2575,26 +2548,9 @@ class RegisterShape(object):
     """Saves the `op_type` as the `Operation` type."""
     if not isinstance(op_type, six.string_types):
       raise TypeError("op_type must be a string")
-    self._op_type = op_type
 
   def __call__(self, f):
-    """Registers "f" as the shape function for "op_type"."""
-    if f is None:
-      assert _call_cpp_shape_fn
-
-      # None is a special "weak" value that provides a default shape function,
-      # and can be overridden by a non-None registration.
-      try:
-        _default_shape_function_registry.register(_call_cpp_shape_fn,
-                                                  self._op_type)
-      except KeyError:
-        # Ignore duplicate registrations of the weak value. This can
-        # occur if the op library input to wrapper generation
-        # inadvertently links in one or more of the standard op
-        # libraries.
-        pass
-    else:
-      _shape_registry.register(f, self._op_type)
+    """No-op."""
     return f
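RegisterShape is now a validation-only stub: it still type-checks the op name, but __call__ returns the function untouched, so legacy registrations keep importing and running without side effects. A sketch with a hypothetical op type:

    from tensorflow.python.framework import ops

    @ops.RegisterShape("MyCustomOp")  # hypothetical op type; only type-checked
    def _my_shape_fn(unused_op):
      return None

    # The decorator records nothing; _my_shape_fn is returned unchanged.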


@@ -35,7 +35,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as eager_function
 from tensorflow.python.eager import wrap_function
-from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
@@ -64,8 +63,6 @@ import tensorflow.python.ops.gradients  # pylint: disable=unused-import
 from tensorflow.python.platform import googletest
 from tensorflow.python.util import compat
 
-ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
-
 
 class ResourceTest(test_util.TensorFlowTestCase):
@@ -2907,9 +2904,6 @@ class AttrScopeTest(test_util.TensorFlowTestCase):
     self.assertAllEqual((None, None), a7)
 
-ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
-
 
 class KernelLabelTest(test_util.TensorFlowTestCase):
 
   @test_util.run_deprecated_v1


@@ -75,7 +75,6 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.python.compat import compat as fwd_compat
 from tensorflow.python.eager import context
-from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import graph_util
@@ -1434,11 +1433,12 @@ def _ReductionDims(x, axis, reduction_indices=None):  # pylint: disable=invalid-
     return axis
   else:
     # Fast path: avoid creating Rank and Range ops if ndims is known.
-    rank = common_shapes.rank(x)
-    if rank is not None:
-      return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
-    if (isinstance(x, sparse_tensor.SparseTensor) and
-        x.dense_shape.shape.is_fully_defined()):
+    if isinstance(x, ops.Tensor):
+      rank = x.shape.rank
+      if rank is not None:
+        return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
+    elif (isinstance(x, sparse_tensor.SparseTensor) and
+          x.dense_shape.shape.is_fully_defined()):
       rank = x.dense_shape.shape.dims[0].value  # sparse.dense_shape is 1-D.
       return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
@@ -1446,9 +1446,14 @@ def _ReductionDims(x, axis, reduction_indices=None):  # pylint: disable=invalid-
     return range(0, array_ops.rank(x))
 
 
+def _has_fully_defined_shape(tensor):
+  """Returns true if tensor has a fully defined shape."""
+  return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined()
+
+
 def _may_reduce_to_scalar(keepdims, axis, output):
   """Set a reduction's output shape to be a scalar if we are certain."""
-  if not common_shapes.has_fully_defined_shape(output) and (not keepdims) and (
+  if not _has_fully_defined_shape(output) and (not keepdims) and (
       axis is None):
     output.set_shape(())
   return output
@@ -3438,14 +3443,6 @@ def conj(x, name=None):
         x.dtype)
 
 
-def _BroadcastShape(op):
-  """Common shape function for binary operators that broadcast their inputs."""
-  return [
-      common_shapes.broadcast_shape(op.inputs[0].get_shape(),
-                                    op.inputs[1].get_shape())
-  ]
-
-
 def reduced_shape(input_shape, axes):
   """Helper function for reduction ops.