diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py index 9c79af24649..983f7640b89 100644 --- a/tensorflow/python/data/experimental/ops/batching.py +++ b/tensorflow/python/data/experimental/ops/batching.py @@ -718,15 +718,15 @@ class _RebatchDataset(dataset_ops.UnaryDataset): """Recalculates the output_shapes after dividing it by num_workers.""" if len(output_shapes) < 1: raise ValueError("Input shape should have at least one dimension.") - if (output_shapes.dims[0].value and - output_shapes.dims[0].value % num_workers != 0): + if (tensor_shape.dimension_value(output_shapes[0]) and + tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0): raise errors.InvalidArgumentError( None, None, "First dim of input shape: %d is not divisible by num_workers: %d" % (output_shapes[0], num_workers)) output_dims = [d for d in output_shapes.dims] output_dims[0] = output_dims[0] // num_workers - return tensor_shape.TensorShapeV1(output_dims) + return tensor_shape.TensorShape(output_dims) output_shapes = nest.map_structure(recalculate_output_shapes, input_dataset.output_shapes) diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index a7537bb5f1a..40fccc86a34 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -74,9 +74,8 @@ def enable_v2_tensorshape(): # in `tensor_shape[i]`, but they would not be. ``` """ - global _TENSORSHAPE_V2_OVERRIDE, TensorShape # pylint: disable=invalid-name + global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name _TENSORSHAPE_V2_OVERRIDE = True - TensorShape = TensorShapeV2 @tf_export(v1=["disable_v2_tensorshape"]) @@ -85,9 +84,8 @@ def disable_v2_tensorshape(): See docstring for `enable_v2_tensorshape` for details about the new behavior. """ - global _TENSORSHAPE_V2_OVERRIDE, TensorShape # pylint: disable=invalid-name + global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name _TENSORSHAPE_V2_OVERRIDE = False - TensorShape = TensorShapeV1 @tf_export("compat.dimension_value", @@ -635,8 +633,8 @@ def as_dimension(value): return Dimension(value) -@tf_export(v1=["TensorShape"]) -class TensorShapeV1(object): +@tf_export("TensorShape") +class TensorShape(object): """Represents the shape of a `Tensor`. A `TensorShape` represents a possibly-partial shape specification for a @@ -695,7 +693,7 @@ class TensorShapeV1(object): @property def _v2_behavior(self): if _TENSORSHAPE_V2_OVERRIDE is None: - return False + return tf2.enabled() return _TENSORSHAPE_V2_OVERRIDE def __repr__(self): @@ -1151,22 +1149,6 @@ def unknown_shape(rank=None, **kwargs): return TensorShape([Dimension(None)] * rank) -@tf_export("TensorShape", v1=[]) -class TensorShapeV2(TensorShapeV1): - - @property - def _v2_behavior(self): - if _TENSORSHAPE_V2_OVERRIDE is None: - return True - return _TENSORSHAPE_V2_OVERRIDE - - -if tf2.enabled(): - TensorShape = TensorShapeV2 -else: - TensorShape = TensorShapeV1 - - def scalar(): """Returns a shape representing a scalar.""" return TensorShape([]) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index a7aa97c2e4f..cf4aa51b6ed 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -928,16 +928,16 @@ def convolution_internal( name=None): """Internal function which performs rank agnostic convolution.""" with ops.name_scope(name, "convolution", [input, filters]) as name: - if isinstance(input.shape, tensor_shape.TensorShapeV1) and \ + if isinstance(input.shape, tensor_shape.TensorShape) and \ input.shape.rank is not None: n = len(input.shape) - 2 - elif not isinstance(input.shape, tensor_shape.TensorShapeV1) and \ + elif not isinstance(input.shape, tensor_shape.TensorShape) and \ input.shape is not None: n = len(input.shape) - 2 - elif isinstance(filters.shape, tensor_shape.TensorShapeV1) and \ + elif isinstance(filters.shape, tensor_shape.TensorShape) and \ filters.shape.rank is not None: n = len(filters.shape) - 2 - elif not isinstance(filters.shape, tensor_shape.TensorShapeV1) and \ + elif not isinstance(filters.shape, tensor_shape.TensorShape) and \ filters.shape is not None: n = len(filters.shape) - 2 else: diff --git a/tensorflow/python/saved_model/nested_structure_coder.py b/tensorflow/python/saved_model/nested_structure_coder.py index 5cf9a5b155b..3d335de5559 100644 --- a/tensorflow/python/saved_model/nested_structure_coder.py +++ b/tensorflow/python/saved_model/nested_structure_coder.py @@ -361,10 +361,7 @@ class _TensorShapeCodec(object): """Codec for `TensorShape`.""" def can_encode(self, pyobj): - return isinstance(pyobj, (tensor_shape.TensorShape, - # TODO(b/121255889): Should not need these. - tensor_shape.TensorShapeV1, - tensor_shape.TensorShapeV2)) + return isinstance(pyobj, tensor_shape.TensorShape) def do_encode(self, tensor_shape_value, encode_fn): del encode_fn diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt index d11e927bd55..60518ffadc8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.TensorShape" tf_class { - is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShapeV1\'>" + is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShape\'>" is_instance: "<type \'object\'>" member { name: "dims" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt index bee19520b77..60518ffadc8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt @@ -1,7 +1,6 @@ path: "tensorflow.TensorShape" tf_class { - is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShapeV2\'>" - is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShapeV1\'>" + is_instance: "<class \'tensorflow.python.framework.tensor_shape.TensorShape\'>" is_instance: "<type \'object\'>" member { name: "dims"