Simplify TensorShape v1/v2 switch (and fix bugs in it). Note that this means that whenever v2 behavior is enabled, all functions (including compat.v1 functions) will return v2 TensorShapes. TensorShape behavior is switched as a global, and not per function.

PiperOrigin-RevId: 233894962
This commit is contained in:
parent
368674dfe2
commit
5979904503
tensorflow
python
data/experimental/ops
framework
ops
saved_model
tools/api/golden
@ -718,15 +718,15 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
|
||||
"""Recalculates the output_shapes after dividing it by num_workers."""
|
||||
if len(output_shapes) < 1:
|
||||
raise ValueError("Input shape should have at least one dimension.")
|
||||
if (output_shapes.dims[0].value and
|
||||
output_shapes.dims[0].value % num_workers != 0):
|
||||
if (tensor_shape.dimension_value(output_shapes[0]) and
|
||||
tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
|
||||
raise errors.InvalidArgumentError(
|
||||
None, None,
|
||||
"First dim of input shape: %d is not divisible by num_workers: %d" %
|
||||
(output_shapes[0], num_workers))
|
||||
output_dims = [d for d in output_shapes.dims]
|
||||
output_dims[0] = output_dims[0] // num_workers
|
||||
return tensor_shape.TensorShapeV1(output_dims)
|
||||
return tensor_shape.TensorShape(output_dims)
|
||||
|
||||
output_shapes = nest.map_structure(recalculate_output_shapes,
|
||||
input_dataset.output_shapes)
|
||||
|
@ -74,9 +74,8 @@ def enable_v2_tensorshape():
|
||||
# in `tensor_shape[i]`, but they would not be.
|
||||
```
|
||||
"""
|
||||
global _TENSORSHAPE_V2_OVERRIDE, TensorShape # pylint: disable=invalid-name
|
||||
global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name
|
||||
_TENSORSHAPE_V2_OVERRIDE = True
|
||||
TensorShape = TensorShapeV2
|
||||
|
||||
|
||||
@tf_export(v1=["disable_v2_tensorshape"])
|
||||
@ -85,9 +84,8 @@ def disable_v2_tensorshape():
|
||||
|
||||
See docstring for `enable_v2_tensorshape` for details about the new behavior.
|
||||
"""
|
||||
global _TENSORSHAPE_V2_OVERRIDE, TensorShape # pylint: disable=invalid-name
|
||||
global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name
|
||||
_TENSORSHAPE_V2_OVERRIDE = False
|
||||
TensorShape = TensorShapeV1
|
||||
|
||||
|
||||
@tf_export("compat.dimension_value",
|
||||
@ -635,8 +633,8 @@ def as_dimension(value):
|
||||
return Dimension(value)
|
||||
|
||||
|
||||
@tf_export(v1=["TensorShape"])
|
||||
class TensorShapeV1(object):
|
||||
@tf_export("TensorShape")
|
||||
class TensorShape(object):
|
||||
"""Represents the shape of a `Tensor`.
|
||||
|
||||
A `TensorShape` represents a possibly-partial shape specification for a
|
||||
@ -695,7 +693,7 @@ class TensorShapeV1(object):
|
||||
@property
|
||||
def _v2_behavior(self):
|
||||
if _TENSORSHAPE_V2_OVERRIDE is None:
|
||||
return False
|
||||
return tf2.enabled()
|
||||
return _TENSORSHAPE_V2_OVERRIDE
|
||||
|
||||
def __repr__(self):
|
||||
@ -1151,22 +1149,6 @@ def unknown_shape(rank=None, **kwargs):
|
||||
return TensorShape([Dimension(None)] * rank)
|
||||
|
||||
|
||||
@tf_export("TensorShape", v1=[])
|
||||
class TensorShapeV2(TensorShapeV1):
|
||||
|
||||
@property
|
||||
def _v2_behavior(self):
|
||||
if _TENSORSHAPE_V2_OVERRIDE is None:
|
||||
return True
|
||||
return _TENSORSHAPE_V2_OVERRIDE
|
||||
|
||||
|
||||
if tf2.enabled():
|
||||
TensorShape = TensorShapeV2
|
||||
else:
|
||||
TensorShape = TensorShapeV1
|
||||
|
||||
|
||||
def scalar():
|
||||
"""Returns a shape representing a scalar."""
|
||||
return TensorShape([])
|
||||
|
@ -928,16 +928,16 @@ def convolution_internal(
|
||||
name=None):
|
||||
"""Internal function which performs rank agnostic convolution."""
|
||||
with ops.name_scope(name, "convolution", [input, filters]) as name:
|
||||
if isinstance(input.shape, tensor_shape.TensorShapeV1) and \
|
||||
if isinstance(input.shape, tensor_shape.TensorShape) and \
|
||||
input.shape.rank is not None:
|
||||
n = len(input.shape) - 2
|
||||
elif not isinstance(input.shape, tensor_shape.TensorShapeV1) and \
|
||||
elif not isinstance(input.shape, tensor_shape.TensorShape) and \
|
||||
input.shape is not None:
|
||||
n = len(input.shape) - 2
|
||||
elif isinstance(filters.shape, tensor_shape.TensorShapeV1) and \
|
||||
elif isinstance(filters.shape, tensor_shape.TensorShape) and \
|
||||
filters.shape.rank is not None:
|
||||
n = len(filters.shape) - 2
|
||||
elif not isinstance(filters.shape, tensor_shape.TensorShapeV1) and \
|
||||
elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
|
||||
filters.shape is not None:
|
||||
n = len(filters.shape) - 2
|
||||
else:
|
||||
|
@ -361,10 +361,7 @@ class _TensorShapeCodec(object):
|
||||
"""Codec for `TensorShape`."""
|
||||
|
||||
def can_encode(self, pyobj):
|
||||
return isinstance(pyobj, (tensor_shape.TensorShape,
|
||||
# TODO(b/121255889): Should not need these.
|
||||
tensor_shape.TensorShapeV1,
|
||||
tensor_shape.TensorShapeV2))
|
||||
return isinstance(pyobj, tensor_shape.TensorShape)
|
||||
|
||||
def do_encode(self, tensor_shape_value, encode_fn):
|
||||
del encode_fn
|
||||
|
@ -1,6 +1,6 @@
|
||||
path: "tensorflow.TensorShape"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShapeV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShape\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "dims"
|
||||
|
@ -1,7 +1,6 @@
|
||||
path: "tensorflow.TensorShape"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShapeV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShapeV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShape\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "dims"
|
||||
|
Loading…
Reference in New Issue
Block a user