Simplify TensorShape v1/v2 switch (and fix bugs in it). Note that this means that whenever v2 behavior is enabled, all functions (including compat.v1 functions) will return

v2 TensorShapes. TensorShape behavior is switched as a global, and not per function.

PiperOrigin-RevId: 233894962
This commit is contained in:
Martin Wicke 2019-02-13 22:41:38 -08:00 committed by TensorFlower Gardener
parent 368674dfe2
commit 5979904503
6 changed files with 15 additions and 37 deletions
tensorflow

View File

@ -718,15 +718,15 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
"""Recalculates the output_shapes after dividing it by num_workers."""
if len(output_shapes) < 1:
raise ValueError("Input shape should have at least one dimension.")
if (output_shapes.dims[0].value and
output_shapes.dims[0].value % num_workers != 0):
if (tensor_shape.dimension_value(output_shapes[0]) and
tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
raise errors.InvalidArgumentError(
None, None,
"First dim of input shape: %d is not divisible by num_workers: %d" %
(output_shapes[0], num_workers))
output_dims = [d for d in output_shapes.dims]
output_dims[0] = output_dims[0] // num_workers
return tensor_shape.TensorShapeV1(output_dims)
return tensor_shape.TensorShape(output_dims)
output_shapes = nest.map_structure(recalculate_output_shapes,
input_dataset.output_shapes)

View File

@ -74,9 +74,8 @@ def enable_v2_tensorshape():
# in `tensor_shape[i]`, but they would not be.
```
"""
global _TENSORSHAPE_V2_OVERRIDE, TensorShape # pylint: disable=invalid-name
global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name
_TENSORSHAPE_V2_OVERRIDE = True
TensorShape = TensorShapeV2
@tf_export(v1=["disable_v2_tensorshape"])
@ -85,9 +84,8 @@ def disable_v2_tensorshape():
See docstring for `enable_v2_tensorshape` for details about the new behavior.
"""
global _TENSORSHAPE_V2_OVERRIDE, TensorShape # pylint: disable=invalid-name
global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name
_TENSORSHAPE_V2_OVERRIDE = False
TensorShape = TensorShapeV1
@tf_export("compat.dimension_value",
@ -635,8 +633,8 @@ def as_dimension(value):
return Dimension(value)
@tf_export(v1=["TensorShape"])
class TensorShapeV1(object):
@tf_export("TensorShape")
class TensorShape(object):
"""Represents the shape of a `Tensor`.
A `TensorShape` represents a possibly-partial shape specification for a
@ -695,7 +693,7 @@ class TensorShapeV1(object):
@property
def _v2_behavior(self):
if _TENSORSHAPE_V2_OVERRIDE is None:
return False
return tf2.enabled()
return _TENSORSHAPE_V2_OVERRIDE
def __repr__(self):
@ -1151,22 +1149,6 @@ def unknown_shape(rank=None, **kwargs):
return TensorShape([Dimension(None)] * rank)
@tf_export("TensorShape", v1=[])
class TensorShapeV2(TensorShapeV1):
@property
def _v2_behavior(self):
if _TENSORSHAPE_V2_OVERRIDE is None:
return True
return _TENSORSHAPE_V2_OVERRIDE
if tf2.enabled():
TensorShape = TensorShapeV2
else:
TensorShape = TensorShapeV1
def scalar():
"""Returns a shape representing a scalar."""
return TensorShape([])

View File

@ -928,16 +928,16 @@ def convolution_internal(
name=None):
"""Internal function which performs rank agnostic convolution."""
with ops.name_scope(name, "convolution", [input, filters]) as name:
if isinstance(input.shape, tensor_shape.TensorShapeV1) and \
if isinstance(input.shape, tensor_shape.TensorShape) and \
input.shape.rank is not None:
n = len(input.shape) - 2
elif not isinstance(input.shape, tensor_shape.TensorShapeV1) and \
elif not isinstance(input.shape, tensor_shape.TensorShape) and \
input.shape is not None:
n = len(input.shape) - 2
elif isinstance(filters.shape, tensor_shape.TensorShapeV1) and \
elif isinstance(filters.shape, tensor_shape.TensorShape) and \
filters.shape.rank is not None:
n = len(filters.shape) - 2
elif not isinstance(filters.shape, tensor_shape.TensorShapeV1) and \
elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
filters.shape is not None:
n = len(filters.shape) - 2
else:

View File

@ -361,10 +361,7 @@ class _TensorShapeCodec(object):
"""Codec for `TensorShape`."""
def can_encode(self, pyobj):
return isinstance(pyobj, (tensor_shape.TensorShape,
# TODO(b/121255889): Should not need these.
tensor_shape.TensorShapeV1,
tensor_shape.TensorShapeV2))
return isinstance(pyobj, tensor_shape.TensorShape)
def do_encode(self, tensor_shape_value, encode_fn):
del encode_fn

View File

@ -1,6 +1,6 @@
path: "tensorflow.TensorShape"
tf_class {
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShapeV1\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShape\'>"
is_instance: "<type \'object\'>"
member {
name: "dims"

View File

@ -1,7 +1,6 @@
path: "tensorflow.TensorShape"
tf_class {
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShapeV2\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShapeV1\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShape\'>"
is_instance: "<type \'object\'>"
member {
name: "dims"