Inlined tensor_shape.{scalar,vector,matrix}

An explicit constructor call is no less clear, and it matches what we export via
the public API.

These functions will be removed once all internal users have been migrated.
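
The rewrite is mechanical: each deprecated factory maps onto one explicit constructor call. A minimal sketch of the mapping, using the private tensor_shape module the diffs below import (TensorShape is the same class exported publicly as tf.TensorShape):

    from tensorflow.python.framework import tensor_shape

    scalar = tensor_shape.TensorShape([])      # was tensor_shape.scalar()
    vector = tensor_shape.TensorShape([5])     # was tensor_shape.vector(5)
    matrix = tensor_shape.TensorShape([2, 3])  # was tensor_shape.matrix(2, 3)

    assert scalar.rank == 0
    assert vector.rank == 1
    assert matrix.rank == 2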

PiperOrigin-RevId: 259620054
Sergei Lebedev 2019-07-23 15:07:01 -07:00 committed by TensorFlower Gardener
parent 2ed843260a
commit bca5e7385f
63 changed files with 223 additions and 221 deletions

View File

@ -476,7 +476,7 @@ class BigtableTable(object):
if tensor_type != dtypes.string: if tensor_type != dtypes.string:
raise ValueError("Not all elements of the dataset were `tf.string`") raise ValueError("Not all elements of the dataset were `tf.string`")
for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)): for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)):
if not shape.is_compatible_with(tensor_shape.scalar()): if not shape.is_compatible_with(tensor_shape.TensorShape([])):
raise ValueError("Not all elements of the dataset were scalars") raise ValueError("Not all elements of the dataset were scalars")
if len(column_families) != len(columns): if len(column_families) != len(columns):
raise ValueError("len(column_families) != len(columns)") raise ValueError("len(column_families) != len(columns)")

View File

@ -60,8 +60,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
indices = [[0, 0], [0, 1], [2, 0], [3, 0]] indices = [[0, 0], [0, 1], [2, 0], [3, 0]]
values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64) values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = categorical_split_handler.EqualitySplitHandler( split_handler = categorical_split_handler.EqualitySplitHandler(
@ -183,8 +183,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
indices = [[0, 0], [1, 0], [2, 0], [3, 0]] indices = [[0, 0], [1, 0], [2, 0], [3, 0]]
values = array_ops.constant([1, 2, 1, 2], dtype=dtypes.int64) values = array_ops.constant([1, 2, 1, 2], dtype=dtypes.int64)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = categorical_split_handler.EqualitySplitHandler( split_handler = categorical_split_handler.EqualitySplitHandler(
@ -294,8 +294,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
indices = [[0, 0], [0, 1], [2, 0], [3, 0]] indices = [[0, 0], [0, 1], [2, 0], [3, 0]]
values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64) values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = categorical_split_handler.EqualitySplitHandler( split_handler = categorical_split_handler.EqualitySplitHandler(
@ -489,8 +489,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2]) indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2])
values = constant_op.constant_v1([], dtype=dtypes.int64) values = constant_op.constant_v1([], dtype=dtypes.int64)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = categorical_split_handler.EqualitySplitHandler( split_handler = categorical_split_handler.EqualitySplitHandler(
@ -537,8 +537,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
indices = [[0, 0], [0, 1], [2, 0], [3, 0]] indices = [[0, 0], [0, 1], [2, 0], [3, 0]]
values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64) values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = categorical_split_handler.EqualitySplitHandler( split_handler = categorical_split_handler.EqualitySplitHandler(
@ -591,8 +591,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
indices = [[0, 0], [0, 1], [2, 0]] indices = [[0, 0], [0, 1], [2, 0]]
values = array_ops.constant([1, 2, 2], dtype=dtypes.int64) values = array_ops.constant([1, 2, 2], dtype=dtypes.int64)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = categorical_split_handler.EqualitySplitHandler( split_handler = categorical_split_handler.EqualitySplitHandler(

View File

@ -75,7 +75,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function from tensorflow.python.framework import function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -261,8 +260,7 @@ class DenseSplitHandler(InequalitySplitHandler):
def make_splits(self, stamp_token, next_stamp_token, class_id): def make_splits(self, stamp_token, next_stamp_token, class_id):
"""Create the best split using the accumulated stats and flush the state.""" """Create the best split using the accumulated stats and flush the state."""
if (self._gradient_shape == tensor_shape.scalar() and if (self._gradient_shape.rank == 0 and self._hessian_shape.rank == 0):
self._hessian_shape == tensor_shape.scalar()):
handler = make_dense_split_scalar handler = make_dense_split_scalar
else: else:
handler = make_dense_split_tensor handler = make_dense_split_tensor
@ -441,8 +439,7 @@ class SparseSplitHandler(InequalitySplitHandler):
def make_splits(self, stamp_token, next_stamp_token, class_id): def make_splits(self, stamp_token, next_stamp_token, class_id):
"""Create the best split using the accumulated stats and flush the state.""" """Create the best split using the accumulated stats and flush the state."""
if (self._gradient_shape == tensor_shape.scalar() and if self._gradient_shape.rank == 0 and self._hessian_shape.rank == 0:
self._hessian_shape == tensor_shape.scalar()):
handler = make_sparse_split_scalar handler = make_sparse_split_scalar
else: else:
handler = make_sparse_split_tensor handler = make_sparse_split_tensor
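
Here the equality test against a freshly constructed scalar shape becomes a rank check. For fully known shapes the two agree, and a shape of unknown rank (whose rank is None) fails both, so the handler dispatch is unchanged; a sketch of the equivalence:

    from tensorflow.python.framework import tensor_shape

    s0 = tensor_shape.TensorShape([])
    s1 = tensor_shape.TensorShape([3])
    assert s0.rank == 0 and s0 == tensor_shape.TensorShape([])  # scalar either way
    assert s1.rank != 0 and s1 != tensor_shape.TensorShape([])  # non-scalar either way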

View File

@ -63,8 +63,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
class_id = -1 class_id = -1
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
split_handler = ordinal_split_handler.DenseSplitHandler( split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1, l1_regularization=0.1,
l2_regularization=1., l2_regularization=1.,
@ -197,8 +197,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
partition_ids = array_ops.constant([1, 1, 1, 2], dtype=dtypes.int32) partition_ids = array_ops.constant([1, 1, 1, 2], dtype=dtypes.int32)
class_id = -1 class_id = -1
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
split_handler = ordinal_split_handler.DenseSplitHandler( split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1, l1_regularization=0.1,
l2_regularization=1., l2_regularization=1.,
@ -333,8 +333,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
class_id = -1 class_id = -1
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
split_handler = ordinal_split_handler.DenseSplitHandler( split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.2, l1_regularization=0.2,
l2_regularization=2., l2_regularization=2.,
@ -645,8 +645,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = ordinal_split_handler.DenseSplitHandler( split_handler = ordinal_split_handler.DenseSplitHandler(
@ -720,8 +720,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = ordinal_split_handler.DenseSplitHandler( split_handler = ordinal_split_handler.DenseSplitHandler(
@ -854,8 +854,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
hessians = array_ops.constant([0.12, 0.07, 0.2, 2]) hessians = array_ops.constant([0.12, 0.07, 0.2, 2])
partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = ordinal_split_handler.DenseSplitHandler( split_handler = ordinal_split_handler.DenseSplitHandler(
@ -965,8 +965,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
values = array_ops.constant([0.52, 0.3, 0.52]) values = array_ops.constant([0.52, 0.3, 0.52])
sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1])
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler( split_handler = ordinal_split_handler.SparseSplitHandler(
@ -1088,8 +1088,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
values = array_ops.constant([0.52, 0.3, 0.52]) values = array_ops.constant([0.52, 0.3, 0.52])
sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1])
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler( split_handler = ordinal_split_handler.SparseSplitHandler(
@ -1411,8 +1411,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
values = array_ops.constant([0.52, 0.3, 0.52]) values = array_ops.constant([0.52, 0.3, 0.52])
sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1])
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler( split_handler = ordinal_split_handler.SparseSplitHandler(
@ -1481,8 +1481,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
values = constant_op.constant_v1([], dtype=dtypes.float32) values = constant_op.constant_v1([], dtype=dtypes.float32)
sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1])
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler( split_handler = ordinal_split_handler.SparseSplitHandler(
@ -1565,8 +1565,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
non_empty_indices, non_empty_values, [4, 2]) non_empty_indices, non_empty_values, [4, 2])
non_empty_sparse_column = non_empty_sparse_column.eval(session=sess) non_empty_sparse_column = non_empty_sparse_column.eval(session=sess)
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler( split_handler = ordinal_split_handler.SparseSplitHandler(
@ -1650,8 +1650,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
values = array_ops.constant([0.58]) values = array_ops.constant([0.58])
sparse_column = sparse_tensor.SparseTensor(indices, values, [1, 1]) sparse_column = sparse_tensor.SparseTensor(indices, values, [1, 1])
gradient_shape = tensor_shape.scalar() gradient_shape = tensor_shape.TensorShape([])
hessian_shape = tensor_shape.scalar() hessian_shape = tensor_shape.TensorShape([])
class_id = -1 class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler( split_handler = ordinal_split_handler.SparseSplitHandler(

View File

@ -32,8 +32,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess: with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator( accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0, stamp_token=0,
gradient_shape=tensor_shape.scalar(), gradient_shape=tensor_shape.TensorShape([]),
hessian_shape=tensor_shape.scalar()) hessian_shape=tensor_shape.TensorShape([]))
with ops.control_dependencies([accumulator.initializer]): with ops.control_dependencies([accumulator.initializer]):
op1 = accumulator.add( op1 = accumulator.add(
stamp_token=0, stamp_token=0,
@ -60,8 +60,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess: with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator( accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0, stamp_token=0,
gradient_shape=tensor_shape.scalar(), gradient_shape=tensor_shape.TensorShape([]),
hessian_shape=tensor_shape.scalar()) hessian_shape=tensor_shape.TensorShape([]))
with ops.control_dependencies([accumulator.initializer]): with ops.control_dependencies([accumulator.initializer]):
op1 = accumulator.add( op1 = accumulator.add(
stamp_token=0, stamp_token=0,
@ -89,8 +89,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess: with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator( accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0, stamp_token=0,
gradient_shape=tensor_shape.scalar(), gradient_shape=tensor_shape.TensorShape([]),
hessian_shape=tensor_shape.scalar()) hessian_shape=tensor_shape.TensorShape([]))
with ops.control_dependencies([accumulator.initializer]): with ops.control_dependencies([accumulator.initializer]):
op1 = accumulator.add( op1 = accumulator.add(
stamp_token=0, stamp_token=0,
@ -121,8 +121,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess: with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator( accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0, stamp_token=0,
gradient_shape=tensor_shape.scalar(), gradient_shape=tensor_shape.TensorShape([]),
hessian_shape=tensor_shape.scalar()) hessian_shape=tensor_shape.TensorShape([]))
with ops.control_dependencies([accumulator.initializer]): with ops.control_dependencies([accumulator.initializer]):
op1 = accumulator.add( op1 = accumulator.add(
stamp_token=0, stamp_token=0,
@ -162,8 +162,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess: with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator( accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0, stamp_token=0,
gradient_shape=tensor_shape.scalar(), gradient_shape=tensor_shape.TensorShape([]),
hessian_shape=tensor_shape.scalar()) hessian_shape=tensor_shape.TensorShape([]))
with ops.control_dependencies([accumulator.initializer]): with ops.control_dependencies([accumulator.initializer]):
# These will be deleted due to deserialize call. # These will be deleted due to deserialize call.
op1 = accumulator.add( op1 = accumulator.add(
@ -199,8 +199,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess: with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator( accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0, stamp_token=0,
gradient_shape=tensor_shape.scalar(), gradient_shape=tensor_shape.TensorShape([]),
hessian_shape=tensor_shape.scalar()) hessian_shape=tensor_shape.TensorShape([]))
partition, feature, grads, hessians = accumulator._make_summary( partition, feature, grads, hessians = accumulator._make_summary(
partition_ids=[1, 2, 1], partition_ids=[1, 2, 1],
feature_ids=[[2, 0], [3, 1], [2, 0]], feature_ids=[[2, 0], [3, 1], [2, 0]],

View File

@ -25,7 +25,6 @@ import six
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -65,7 +64,7 @@ def _move_tensors(tensors, device):
# logic. # logic.
zero = constant_op.constant(0, dtype=dtypes.int32) zero = constant_op.constant(0, dtype=dtypes.int32)
with ops.device(None): with ops.device(None):
if all(tensor.shape == tensor_shape.scalar() for tensor in tensors): if all(tensor.shape.rank == 0 for tensor in tensors):
with ops.device(tensors[0].device): with ops.device(tensors[0].device):
values = array_ops.stack(tensors) values = array_ops.stack(tensors)
with ops.device(device): with ops.device(device):

View File

@ -23,7 +23,6 @@ from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader
# pylint: enable=unused-import # pylint: enable=unused-import
from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import resources from tensorflow.python.ops import resources
from tensorflow.python.training import saver from tensorflow.python.training import saver
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
@ -134,8 +133,7 @@ class StatsAccumulator(tracking.TrackableResource):
self._hessian_shape = hessian_shape self._hessian_shape = hessian_shape
self._container = container self._container = container
if (gradient_shape == tensor_shape.scalar() and if (gradient_shape.rank == 0 and hessian_shape.rank == 0):
hessian_shape == tensor_shape.scalar()):
self._is_scalar = True self._is_scalar = True
else: else:
self._is_scalar = False self._is_scalar = False

View File

@ -368,8 +368,8 @@ class GradientBoostedDecisionTreeModel(object):
if logits_dimension == 1 or learner_config.multi_class_strategy == ( if logits_dimension == 1 or learner_config.multi_class_strategy == (
learner_pb2.LearnerConfig.TREE_PER_CLASS): learner_pb2.LearnerConfig.TREE_PER_CLASS):
self._gradient_shape = tensor_shape.scalar() self._gradient_shape = tensor_shape.TensorShape([])
self._hessian_shape = tensor_shape.scalar() self._hessian_shape = tensor_shape.TensorShape([])
else: else:
if center_bias: if center_bias:
raise ValueError("Center bias should be False for multiclass.") raise ValueError("Center bias should be False for multiclass.")
@ -838,8 +838,8 @@ class GradientBoostedDecisionTreeModel(object):
# Create steps accumulator. # Create steps accumulator.
steps_accumulator = stats_accumulator_ops.StatsAccumulator( steps_accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0, stamp_token=0,
gradient_shape=tensor_shape.scalar(), gradient_shape=tensor_shape.TensorShape([]),
hessian_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.TensorShape([]),
name="StepsAccumulator") name="StepsAccumulator")
# Create ensemble stats summaries. # Create ensemble stats summaries.
summary.scalar("layer_stats/num_examples", num_layer_examples) summary.scalar("layer_stats/num_examples", num_layer_examples)
@ -1212,7 +1212,7 @@ class GradientBoostedDecisionTreeModel(object):
def _get_weights(self, hessian_shape, hessians): def _get_weights(self, hessian_shape, hessians):
"""Derives weights to be used based on hessians and multiclass strategy.""" """Derives weights to be used based on hessians and multiclass strategy."""
if hessian_shape == tensor_shape.scalar(): if hessian_shape.rank == 0:
# This is tree per class. # This is tree per class.
weights = hessians weights = hessians
elif len(hessian_shape.dims) == 1: elif len(hessian_shape.dims) == 1:
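
_get_weights dispatches on the hessian's rank: rank 0 means one scalar hessian per example (the tree-per-class strategy), rank 1 a diagonal hessian, and anything else a full hessian matrix. A hypothetical condensed sketch of that dispatch (the helper name is illustrative):

    def _hessian_kind(hessian_shape):
        # Mirrors the branching in _get_weights above.
        if hessian_shape.rank == 0:
            return "scalar"    # tree per class: use hessians as-is
        elif hessian_shape.rank == 1:
            return "diagonal"
        else:
            return "full"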

View File

@ -191,10 +191,8 @@ class BatchReshape(distribution_lib.Distribution):
self.distribution.survival_function, x) self.distribution.survival_function, x)
def _entropy(self): def _entropy(self):
return self._call_and_reshape_output( return self._call_and_reshape_output(self.distribution.entropy, [],
self.distribution.entropy, [tensor_shape.TensorShape([])])
[],
[tensor_shape.scalar()])
def _mean(self): def _mean(self):
return self._call_and_reshape_output(self.distribution.mean) return self._call_and_reshape_output(self.distribution.mean)

View File

@ -230,7 +230,7 @@ class Binomial(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
@distribution_util.AppendDocstring(_binomial_sample_note) @distribution_util.AppendDocstring(_binomial_sample_note)
def _log_prob(self, counts): def _log_prob(self, counts):
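
For scalar-event distributions such as Binomial, the static event shape TensorShape([]) must agree with the dynamic one returned a few lines up, an empty int32 vector of dimension sizes. A sketch of the pairing:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import tensor_shape

    static = tensor_shape.TensorShape([])                    # _event_shape
    dynamic = constant_op.constant([], dtype=dtypes.int32)   # _event_shape_tensor
    assert static.rank == 0
    assert dynamic.shape.as_list() == [0]  # zero dimension sizes, i.e. a scalar event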

View File

@ -173,7 +173,7 @@ class Cauchy(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)

View File

@ -281,7 +281,7 @@ class Deterministic(_BaseDeterministic):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _prob(self, x): def _prob(self, x):
return math_ops.cast( return math_ops.cast(

View File

@ -132,7 +132,7 @@ class Geometric(distribution.Distribution):
return array_ops.constant([], dtype=dtypes.int32) return array_ops.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
# Uniform variates must be sampled from the open-interval `(0, 1)` rather # Uniform variates must be sampled from the open-interval `(0, 1)` rather

View File

@ -178,7 +178,7 @@ class _Gumbel(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
# Uniform variates must be sampled from the open-interval `(0, 1)` rather # Uniform variates must be sampled from the open-interval `(0, 1)` rather

View File

@ -150,7 +150,7 @@ class HalfNormal(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)

View File

@ -187,7 +187,7 @@ class InverseGamma(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
@distribution_util.AppendDocstring( @distribution_util.AppendDocstring(
"""Note: See `tf.random.gamma` docstring for sampling details and """Note: See `tf.random.gamma` docstring for sampling details and

View File

@ -173,7 +173,7 @@ class Logistic(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
# Uniform variates must be sampled from the open-interval `(0, 1)` rather # Uniform variates must be sampled from the open-interval `(0, 1)` rather

View File

@ -145,7 +145,7 @@ class NegativeBinomial(distribution.Distribution):
return array_ops.constant([], dtype=dtypes.int32) return array_ops.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
# Here we use the fact that if: # Here we use the fact that if:

View File

@ -151,7 +151,7 @@ class Poisson(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
@distribution_util.AppendDocstring(_poisson_sample_note) @distribution_util.AppendDocstring(_poisson_sample_note)
def _log_prob(self, x): def _log_prob(self, x):

View File

@ -355,7 +355,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
self.mixture_distribution.logits.shape)[:-1] self.mixture_distribution.logits.shape)[:-1]
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
# Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get

View File

@ -162,7 +162,7 @@ class ModelFnOps(
loss_shape = loss.get_shape() loss_shape = loss.get_shape()
if loss_shape.num_elements() not in (None, 1): if loss_shape.num_elements() not in (None, 1):
raise ValueError('Loss must be scalar: %s.' % loss) raise ValueError('Loss must be scalar: %s.' % loss)
if not loss_shape.is_compatible_with(tensor_shape.scalar()): if not loss_shape.is_compatible_with(tensor_shape.TensorShape([])):
loss = array_ops.reshape(loss, []) loss = array_ops.reshape(loss, [])
# Validate predictions. # Validate predictions.

View File

@ -19,12 +19,11 @@ from __future__ import print_function
import numbers import numbers
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
@ -61,7 +60,7 @@ def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylin
keep_prob = ops.convert_to_tensor(keep_prob, keep_prob = ops.convert_to_tensor(keep_prob,
dtype=x.dtype, dtype=x.dtype,
name="keep_prob") name="keep_prob")
keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) keep_prob.get_shape().assert_has_rank(0)
# Do nothing if we know keep_prob == 1 # Do nothing if we know keep_prob == 1
if tensor_util.constant_value(keep_prob) == 1: if tensor_util.constant_value(keep_prob) == 1:
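
assert_has_rank(0) is the terse equivalent of asserting compatibility with a scalar shape: both pass when the rank is unknown and raise ValueError only for a known non-zero rank. Sketch:

    from tensorflow.python.framework import tensor_shape

    tensor_shape.TensorShape([]).assert_has_rank(0)    # passes
    tensor_shape.TensorShape(None).assert_has_rank(0)  # unknown rank also passes
    try:
        tensor_shape.TensorShape([2]).assert_has_rank(0)
    except ValueError:
        pass  # a known rank of 1 is rejected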

View File

@ -144,14 +144,16 @@ class ParallelReaderTest(test.TestCase):
capacity=55, capacity=55,
min_after_dequeue=28, min_after_dequeue=28,
dtypes=[dtypes_lib.string, dtypes_lib.string], dtypes=[dtypes_lib.string, dtypes_lib.string],
shapes=[tensor_shape.scalar(), tensor_shape.scalar()]) shapes=[tensor_shape.TensorShape([]),
tensor_shape.TensorShape([])])
self._verify_read_up_to_out(shared_queue) self._verify_read_up_to_out(shared_queue)
def testReadUpToFromFIFOQueue(self): def testReadUpToFromFIFOQueue(self):
shared_queue = data_flow_ops.FIFOQueue( shared_queue = data_flow_ops.FIFOQueue(
capacity=99, capacity=99,
dtypes=[dtypes_lib.string, dtypes_lib.string], dtypes=[dtypes_lib.string, dtypes_lib.string],
shapes=[tensor_shape.scalar(), tensor_shape.scalar()]) shapes=[tensor_shape.TensorShape([]),
tensor_shape.TensorShape([])])
self._verify_read_up_to_out(shared_queue) self._verify_read_up_to_out(shared_queue)

View File

@ -212,7 +212,7 @@ def bucket(tensors,
else static_batch_size) else static_batch_size)
bucket_shapes = [ bucket_shapes = [
tensor_shape.vector(maybe_static_batch_size).concatenate(s) tensor_shape.TensorShape([maybe_static_batch_size]).concatenate(s)
for s in bucket_queues[0].shapes for s in bucket_queues[0].shapes
] ]
# top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO # top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO
@ -222,7 +222,7 @@ def bucket(tensors,
top_queue = data_flow_ops.PaddingFIFOQueue( top_queue = data_flow_ops.PaddingFIFOQueue(
capacity=capacity, capacity=capacity,
dtypes=[dtypes.int32] + types, dtypes=[dtypes.int32] + types,
shapes=[tensor_shape.scalar()] + bucket_shapes, shapes=[tensor_shape.TensorShape([])] + bucket_shapes,
shared_name=shared_name, shared_name=shared_name,
name="top_queue") name="top_queue")
@ -403,7 +403,7 @@ def bucket_by_sequence_length(input_length,
which_bucket = math_ops.cast(which_bucket, dtypes.int32) which_bucket = math_ops.cast(which_bucket, dtypes.int32)
if shapes is not None: if shapes is not None:
shapes = [tensor_shape.scalar()] + shapes shapes = [tensor_shape.TensorShape([])] + shapes
_, dequeued = bucket( _, dequeued = bucket(
tensors=[input_length] + tensor_list, tensors=[input_length] + tensor_list,
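
TensorShape([n]).concatenate(s) builds the same batched shape tensor_shape.vector(n) used to: a leading batch dimension (possibly None) followed by the element shape. Sketch:

    from tensorflow.python.framework import tensor_shape

    element = tensor_shape.TensorShape([28, 28])
    assert tensor_shape.TensorShape([32]).concatenate(element).as_list() == [32, 28, 28]
    assert tensor_shape.TensorShape([None]).concatenate(element).as_list() == [None, 28, 28]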

View File

@ -46,7 +46,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping.group_by_reducer(lambda x: x % 2, reducer)) grouping.group_by_reducer(lambda x: x % 2, reducer))
self.assertDatasetProduces( self.assertDatasetProduces(
dataset, dataset,
expected_shapes=tensor_shape.scalar(), expected_shapes=tensor_shape.TensorShape([]),
expected_output=[(i - 1) * i, i * i]) expected_output=[(i - 1) * i, i * i])
def testAverage(self): def testAverage(self):
@ -65,7 +65,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer)) lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
self.assertDatasetProduces( self.assertDatasetProduces(
dataset, dataset,
expected_shapes=tensor_shape.scalar(), expected_shapes=tensor_shape.TensorShape([]),
expected_output=[i - 1, i]) expected_output=[i - 1, i])
def testConcat(self): def testConcat(self):
@ -81,8 +81,8 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping.group_by_reducer(lambda x, y: y % 2, reducer)) grouping.group_by_reducer(lambda x, y: y % 2, reducer))
self.assertDatasetProduces( self.assertDatasetProduces(
dataset, dataset,
expected_shapes=tensor_shape.scalar(), expected_shapes=tensor_shape.TensorShape([]),
expected_output=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]]) expected_output=[b"acegikmoqs"[:i], b"bdfhjlnprt"[:i]])
def testSparseSum(self): def testSparseSum(self):
def _sparse(i): def _sparse(i):
@ -100,7 +100,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer)) grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
self.assertDatasetProduces( self.assertDatasetProduces(
dataset, dataset,
expected_shapes=tensor_shape.scalar(), expected_shapes=tensor_shape.TensorShape([]),
expected_output=[(i - 1) * i, i * i]) expected_output=[(i - 1) * i, i * i])
def testChangingStateShape(self): def testChangingStateShape(self):

View File

@ -244,7 +244,7 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
self._batch_size = batch_size self._batch_size = batch_size
self._row_shape = row_shape self._row_shape = row_shape
self._element_spec = sparse_tensor.SparseTensorSpec( self._element_spec = sparse_tensor.SparseTensorSpec(
tensor_shape.vector(None).concatenate(self._row_shape), tensor_shape.TensorShape([None]).concatenate(self._row_shape),
dataset_ops.get_legacy_output_types(input_dataset)) dataset_ops.get_legacy_output_types(input_dataset))
if compat.forward_compatible(2019, 8, 3): if compat.forward_compatible(2019, 8, 3):

View File

@ -142,7 +142,7 @@ class DatasetCheckpointTest(test_base.DatasetTestBase):
with ops.Graph().as_default() as g: with ops.Graph().as_default() as g:
# Create an empty IteratorResource and restore the Iterator into it. # Create an empty IteratorResource and restore the Iterator into it.
output_types = dtypes.int64 output_types = dtypes.int64
output_shapes = tensor_shape.scalar() output_shapes = tensor_shape.TensorShape([])
iterator = iterator_ops.Iterator.from_structure(output_types, iterator = iterator_ops.Iterator.from_structure(output_types,
output_shapes) output_shapes)
restore_op = self._restore_op(iterator._iterator_resource) restore_op = self._restore_op(iterator._iterator_resource)

View File

@ -287,7 +287,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset_ops.get_structure(dataset), expected_element_structure)) dataset_ops.get_structure(dataset), expected_element_structure))
self.assertEqual([dtypes.variant], self.assertEqual([dtypes.variant],
structure.get_flat_tensor_types(dataset_structure)) structure.get_flat_tensor_types(dataset_structure))
self.assertEqual([tensor_shape.scalar()], self.assertEqual([tensor_shape.TensorShape([])],
structure.get_flat_tensor_shapes(dataset_structure)) structure.get_flat_tensor_shapes(dataset_structure))
# Assert that the `Dataset` survives a round-trip via _from_tensor_list() # Assert that the `Dataset` survives a round-trip via _from_tensor_list()

View File

@ -290,7 +290,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_value_structure)) expected_value_structure))
self.assertEqual([dtypes.variant], self.assertEqual([dtypes.variant],
structure.get_flat_tensor_types(opt_structure)) structure.get_flat_tensor_types(opt_structure))
self.assertEqual([tensor_shape.scalar()], self.assertEqual([tensor_shape.TensorShape([])],
structure.get_flat_tensor_shapes(opt_structure)) structure.get_flat_tensor_shapes(opt_structure))
# All OptionalSpec objects are not compatible with a non-optional # All OptionalSpec objects are not compatible with a non-optional

View File

@ -3165,7 +3165,7 @@ def _padding_value_to_tensor(value, output_type):
TypeError: if the padding value's type does not match `output_type`. TypeError: if the padding value's type does not match `output_type`.
""" """
value = ops.convert_to_tensor(value, name="padding_value") value = ops.convert_to_tensor(value, name="padding_value")
if not value.shape.is_compatible_with(tensor_shape.scalar()): if not value.shape.is_compatible_with(tensor_shape.TensorShape([])):
raise ValueError("Padding value should be a scalar, but is not: %s" % value) raise ValueError("Padding value should be a scalar, but is not: %s" % value)
if value.dtype != output_type: if value.dtype != output_type:
raise TypeError("Padding value tensor (%s) does not match output type: %s" % raise TypeError("Padding value tensor (%s) does not match output type: %s" %
@ -3229,10 +3229,10 @@ class PaddedBatchDataset(UnaryDataset):
drop_remainder, dtype=dtypes.bool, name="drop_remainder") drop_remainder, dtype=dtypes.bool, name="drop_remainder")
def _padded_shape_to_batch_shape(s): def _padded_shape_to_batch_shape(s):
return tensor_shape.vector( return tensor_shape.TensorShape([
tensor_util.constant_value(self._batch_size) if smart_cond. tensor_util.constant_value(self._batch_size)
smart_constant_value(self._drop_remainder) else None).concatenate( if smart_cond.smart_constant_value(self._drop_remainder) else None
tensor_util.constant_value_as_shape(s)) ]).concatenate(tensor_util.constant_value_as_shape(s))
output_shapes = nest.map_structure( output_shapes = nest.map_structure(
_padded_shape_to_batch_shape, self._padded_shapes) _padded_shape_to_batch_shape, self._padded_shapes)
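
_padded_shape_to_batch_shape gives the batch a static leading dimension only when drop_remainder is statically True, because only then does every emitted batch hold exactly batch_size elements. A standalone sketch of the rule (function and argument names here are illustrative):

    from tensorflow.python.framework import tensor_shape

    def batch_shape(batch_size, drop_remainder_static, element_shape):
        leading = batch_size if drop_remainder_static else None
        return tensor_shape.TensorShape([leading]).concatenate(element_shape)

    assert batch_shape(32, True, tensor_shape.TensorShape([10])).as_list() == [32, 10]
    assert batch_shape(32, False, tensor_shape.TensorShape([10])).as_list() == [None, 10]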

View File

@ -53,7 +53,7 @@ def _create_or_validate_filenames_dataset(filenames):
raise TypeError( raise TypeError(
"`filenames` must be a `tf.data.Dataset` of `tf.string` elements.") "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.")
if not dataset_ops.get_legacy_output_shapes(filenames).is_compatible_with( if not dataset_ops.get_legacy_output_shapes(filenames).is_compatible_with(
tensor_shape.scalar()): tensor_shape.TensorShape([])):
raise TypeError( raise TypeError(
"`filenames` must be a `tf.data.Dataset` of scalar `tf.string` " "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
"elements.") "elements.")

View File

@ -87,64 +87,67 @@ class SparseTest(test.TestCase):
"expected": () "expected": ()
}, },
{ {
"types": tensor_shape.scalar(), "types": tensor_shape.TensorShape([]),
"classes": ops.Tensor, "classes": ops.Tensor,
"expected": tensor_shape.scalar() "expected": tensor_shape.TensorShape([])
}, },
{ {
"types": tensor_shape.scalar(), "types": tensor_shape.TensorShape([]),
"classes": sparse_tensor.SparseTensor, "classes": sparse_tensor.SparseTensor,
"expected": tensor_shape.unknown_shape() "expected": tensor_shape.unknown_shape()
}, },
{ {
"types": (tensor_shape.scalar()), "types": (tensor_shape.TensorShape([])),
"classes": (ops.Tensor), "classes": (ops.Tensor),
"expected": (tensor_shape.scalar()) "expected": (tensor_shape.TensorShape([]))
}, },
{ {
"types": (tensor_shape.scalar()), "types": (tensor_shape.TensorShape([])),
"classes": (sparse_tensor.SparseTensor), "classes": (sparse_tensor.SparseTensor),
"expected": (tensor_shape.unknown_shape()) "expected": (tensor_shape.unknown_shape())
}, },
{ {
"types": (tensor_shape.scalar(), ()), "types": (tensor_shape.TensorShape([]), ()),
"classes": (ops.Tensor, ()), "classes": (ops.Tensor, ()),
"expected": (tensor_shape.scalar(), ()) "expected": (tensor_shape.TensorShape([]), ())
}, },
{ {
"types": ((), tensor_shape.scalar()), "types": ((), tensor_shape.TensorShape([])),
"classes": ((), ops.Tensor), "classes": ((), ops.Tensor),
"expected": ((), tensor_shape.scalar()) "expected": ((), tensor_shape.TensorShape([]))
}, },
{ {
"types": (tensor_shape.scalar(), ()), "types": (tensor_shape.TensorShape([]), ()),
"classes": (sparse_tensor.SparseTensor, ()), "classes": (sparse_tensor.SparseTensor, ()),
"expected": (tensor_shape.unknown_shape(), ()) "expected": (tensor_shape.unknown_shape(), ())
}, },
{ {
"types": ((), tensor_shape.scalar()), "types": ((), tensor_shape.TensorShape([])),
"classes": ((), sparse_tensor.SparseTensor), "classes": ((), sparse_tensor.SparseTensor),
"expected": ((), tensor_shape.unknown_shape()) "expected": ((), tensor_shape.unknown_shape())
}, },
{ {
"types": (tensor_shape.scalar(), (), tensor_shape.scalar()), "types": (tensor_shape.TensorShape([]), (),
tensor_shape.TensorShape([])),
"classes": (ops.Tensor, (), ops.Tensor), "classes": (ops.Tensor, (), ops.Tensor),
"expected": (tensor_shape.scalar(), (), tensor_shape.scalar()) "expected": (tensor_shape.TensorShape([]), (),
tensor_shape.TensorShape([]))
}, },
{ {
"types": (tensor_shape.scalar(), (), tensor_shape.scalar()), "types": (tensor_shape.TensorShape([]), (),
"classes": (sparse_tensor.SparseTensor, (), tensor_shape.TensorShape([])),
sparse_tensor.SparseTensor), "classes":
(sparse_tensor.SparseTensor, (), sparse_tensor.SparseTensor),
"expected": (tensor_shape.unknown_shape(), (), "expected": (tensor_shape.unknown_shape(), (),
tensor_shape.unknown_shape()) tensor_shape.unknown_shape())
}, },
{ {
"types": ((), tensor_shape.scalar(), ()), "types": ((), tensor_shape.TensorShape([]), ()),
"classes": ((), ops.Tensor, ()), "classes": ((), ops.Tensor, ()),
"expected": ((), tensor_shape.scalar(), ()) "expected": ((), tensor_shape.TensorShape([]), ())
}, },
{ {
"types": ((), tensor_shape.scalar(), ()), "types": ((), tensor_shape.TensorShape([]), ()),
"classes": ((), sparse_tensor.SparseTensor, ()), "classes": ((), sparse_tensor.SparseTensor, ()),
"expected": ((), tensor_shape.unknown_shape(), ()) "expected": ((), tensor_shape.unknown_shape(), ())
}, },

View File

@ -525,40 +525,43 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
structure.from_tensor_list(s_2, flat_s_1) structure.from_tensor_list(s_2, flat_s_1)
@parameterized.named_parameters( @parameterized.named_parameters(
("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor, ("Tensor", dtypes.float32, tensor_shape.TensorShape(
tensor_spec.TensorSpec([], dtypes.float32)), []), ops.Tensor, tensor_spec.TensorSpec([], dtypes.float32)),
("SparseTensor", dtypes.int32, tensor_shape.matrix( ("SparseTensor", dtypes.int32, tensor_shape.TensorShape(
2, 2), sparse_tensor.SparseTensor, [2, 2]), sparse_tensor.SparseTensor,
sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)), sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)),
("TensorArray_0", dtypes.int32, tensor_shape.as_shape( ("TensorArray_0", dtypes.int32,
[None, True, 2, 2]), tensor_array_ops.TensorArray, tensor_shape.TensorShape([None, True, 2, 2
]), tensor_array_ops.TensorArray,
tensor_array_ops.TensorArraySpec( tensor_array_ops.TensorArraySpec(
[2, 2], dtypes.int32, dynamic_size=None, infer_shape=True)), [2, 2], dtypes.int32, dynamic_size=None, infer_shape=True)),
("TensorArray_1", dtypes.int32, tensor_shape.as_shape( ("TensorArray_1", dtypes.int32,
[True, None, 2, 2]), tensor_array_ops.TensorArray, tensor_shape.TensorShape([True, None, 2, 2
]), tensor_array_ops.TensorArray,
tensor_array_ops.TensorArraySpec( tensor_array_ops.TensorArraySpec(
[2, 2], dtypes.int32, dynamic_size=True, infer_shape=None)), [2, 2], dtypes.int32, dynamic_size=True, infer_shape=None)),
("TensorArray_2", dtypes.int32, tensor_shape.as_shape( ("TensorArray_2", dtypes.int32,
[True, False, 2, 2]), tensor_array_ops.TensorArray, tensor_shape.TensorShape([True, False, 2, 2
]), tensor_array_ops.TensorArray,
tensor_array_ops.TensorArraySpec( tensor_array_ops.TensorArraySpec(
[2, 2], dtypes.int32, dynamic_size=True, infer_shape=False)), [2, 2], dtypes.int32, dynamic_size=True, infer_shape=False)),
("RaggedTensor", dtypes.int32, tensor_shape.matrix( ("RaggedTensor", dtypes.int32, tensor_shape.TensorShape([2, None]),
2, None), ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1), ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1)), ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1)),
("Nested", { ("Nested", {
"a": dtypes.float32, "a": dtypes.float32,
"b": (dtypes.int32, dtypes.string) "b": (dtypes.int32, dtypes.string)
}, { }, {
"a": tensor_shape.scalar(), "a": tensor_shape.TensorShape([]),
"b": (tensor_shape.matrix(2, 2), tensor_shape.scalar()) "b": (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([]))
}, { }, {
"a": ops.Tensor, "a": ops.Tensor,
"b": (sparse_tensor.SparseTensor, ops.Tensor) "b": (sparse_tensor.SparseTensor, ops.Tensor)
}, { }, {
"a": "a":
tensor_spec.TensorSpec([], dtypes.float32), tensor_spec.TensorSpec([], dtypes.float32),
"b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), "b": (sparse_tensor.SparseTensorSpec(
tensor_spec.TensorSpec([], dtypes.string)) [2, 2], dtypes.int32), tensor_spec.TensorSpec([], dtypes.string))
}), }),
) )
def testConvertLegacyStructure(self, output_types, output_shapes, def testConvertLegacyStructure(self, output_types, output_shapes,

View File

@ -683,7 +683,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
compiled = def_function.function(f) compiled = def_function.function(f)
var_handle = compiled() var_handle = compiled()
self.assertEqual(var_handle.dtype, dtypes.resource) self.assertEqual(var_handle.dtype, dtypes.resource)
self.assertEqual(var_handle.shape, tensor_shape.scalar()) self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
@ -760,7 +760,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
compiled = def_function.function(f) compiled = def_function.function(f)
var_handle = compiled() var_handle = compiled()
self.assertEqual(var_handle.dtype, dtypes.resource) self.assertEqual(var_handle.dtype, dtypes.resource)
self.assertEqual(var_handle.shape, tensor_shape.scalar()) self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
@ -790,14 +790,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def f(): def f():
tl, value = list_ops.tensor_list_pop_back( tl, value = list_ops.tensor_list_pop_back(
tensor_list, element_dtype=dtypes.float32) tensor_list, element_dtype=dtypes.float32)
self.assertEqual(value.shape, tensor_shape.scalar()) self.assertEqual(value.shape, tensor_shape.TensorShape([]))
return tl return tl
compiled = def_function.function(f) compiled = def_function.function(f)
output_tensor_list = compiled() output_tensor_list = compiled()
_, value = list_ops.tensor_list_pop_back( _, value = list_ops.tensor_list_pop_back(
output_tensor_list, element_dtype=dtypes.float32) output_tensor_list, element_dtype=dtypes.float32)
self.assertEqual(value.shape, tensor_shape.scalar()) self.assertEqual(value.shape, tensor_shape.TensorShape([]))
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testDefunForcesResourceVariables(self): def testDefunForcesResourceVariables(self):

View File

@ -2462,7 +2462,7 @@ class _EmbeddingColumn(
@property @property
def _variable_shape(self): def _variable_shape(self):
if not hasattr(self, '_shape'): if not hasattr(self, '_shape'):
self._shape = tensor_shape.vector(self.dimension) self._shape = tensor_shape.TensorShape([self.dimension])
return self._shape return self._shape
def _get_dense_tensor_internal(self, def _get_dense_tensor_internal(self,
@ -2573,7 +2573,7 @@ class _SharedEmbeddingColumn(
@property @property
def _variable_shape(self): def _variable_shape(self):
if not hasattr(self, '_shape'): if not hasattr(self, '_shape'):
self._shape = tensor_shape.vector(self.dimension) self._shape = tensor_shape.TensorShape([self.dimension])
return self._shape return self._shape
def _get_dense_tensor_internal(self, def _get_dense_tensor_internal(self,

View File

@ -3134,7 +3134,7 @@ class EmbeddingColumn(
@property @property
def variable_shape(self): def variable_shape(self):
"""See `DenseColumn` base class.""" """See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension) return tensor_shape.TensorShape([self.dimension])
@property @property
@deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
@ -3418,7 +3418,8 @@ class SharedEmbeddingColumn(
@property @property
def variable_shape(self): def variable_shape(self):
"""See `DenseColumn` base class.""" """See `DenseColumn` base class."""
return tensor_shape.vector(self.shared_embedding_column_creator.dimension) return tensor_shape.TensorShape(
[self.shared_embedding_column_creator.dimension])
@property @property
def _variable_shape(self): def _variable_shape(self):

View File

@ -42,7 +42,7 @@ def rank(tensor):
def scalar_shape(unused_op): def scalar_shape(unused_op):
"""Shape function for ops that output a scalar value.""" """Shape function for ops that output a scalar value."""
return [tensor_shape.scalar()] return [tensor_shape.TensorShape([])]
def unchanged_shape(op): def unchanged_shape(op):

View File

@ -63,11 +63,11 @@ class CommonShapesTest(test_util.TensorFlowTestCase):
self.assertEqual(expected, common_shapes.broadcast_shape(shape2, shape1)) self.assertEqual(expected, common_shapes.broadcast_shape(shape2, shape1))
def testBroadcast_one_dimension(self): def testBroadcast_one_dimension(self):
s1 = tensor_shape.vector(5) s1 = tensor_shape.TensorShape([5])
s2 = tensor_shape.vector(7) s2 = tensor_shape.TensorShape([7])
unknown = tensor_shape.unknown_shape() unknown = tensor_shape.unknown_shape()
scalar = tensor_shape.scalar() scalar = tensor_shape.TensorShape([])
expanded_scalar = tensor_shape.TensorShape([1]) expanded_scalar = tensor_shape.TensorShape([1])
# Tensors with same shape should have the same broadcast result. # Tensors with same shape should have the same broadcast result.
@ -90,13 +90,13 @@ class CommonShapesTest(test_util.TensorFlowTestCase):
def testBroadcast_many_dimensions(self): def testBroadcast_many_dimensions(self):
unknown = tensor_shape.unknown_shape() unknown = tensor_shape.unknown_shape()
shape_0 = tensor_shape.scalar() shape_0 = tensor_shape.TensorShape([])
shape_1 = tensor_shape.vector(1) shape_1 = tensor_shape.TensorShape([1])
shape_4 = tensor_shape.vector(4) shape_4 = tensor_shape.TensorShape([4])
shape_1x4 = tensor_shape.matrix(1, 4) shape_1x4 = tensor_shape.TensorShape([1, 4])
shape_4x1 = tensor_shape.matrix(4, 1) shape_4x1 = tensor_shape.TensorShape([4, 1])
shape_3x4 = tensor_shape.matrix(3, 4) shape_3x4 = tensor_shape.TensorShape([3, 4])
shape_4x3 = tensor_shape.matrix(4, 3) shape_4x3 = tensor_shape.TensorShape([4, 3])
# Tensors with same shape should have the same broadcast result. # Tensors with same shape should have the same broadcast result.
for shape in ( for shape in (
@ -113,7 +113,7 @@ class CommonShapesTest(test_util.TensorFlowTestCase):
self._assert_broadcast(expected=unknown, shape1=shape, shape2=unknown) self._assert_broadcast(expected=unknown, shape1=shape, shape2=unknown)
self._assert_broadcast(expected=shape_1x4, shape1=shape_4, shape2=shape_1x4) self._assert_broadcast(expected=shape_1x4, shape1=shape_4, shape2=shape_1x4)
shape_4x4 = tensor_shape.matrix(4, 4) shape_4x4 = tensor_shape.TensorShape([4, 4])
self._assert_broadcast(expected=shape_4x4, shape1=shape_4, shape2=shape_4x1) self._assert_broadcast(expected=shape_4x4, shape1=shape_4, shape2=shape_4x1)
self._assert_broadcast(expected=shape_3x4, shape1=shape_4, shape2=shape_3x4) self._assert_broadcast(expected=shape_3x4, shape1=shape_4, shape2=shape_3x4)
self._assert_incompatible_broadcast(shape1=shape_4, shape2=shape_4x3) self._assert_incompatible_broadcast(shape1=shape_4, shape2=shape_4x3)
@ -155,14 +155,14 @@ class CommonShapesTest(test_util.TensorFlowTestCase):
def testBroadcast_unknown_dims(self): def testBroadcast_unknown_dims(self):
unknown = tensor_shape.unknown_shape() unknown = tensor_shape.unknown_shape()
shape_0 = tensor_shape.scalar() shape_0 = tensor_shape.TensorShape([])
shape_1 = tensor_shape.vector(1) shape_1 = tensor_shape.TensorShape([1])
# pylint: disable=invalid-name # pylint: disable=invalid-name
shape_U = tensor_shape.vector(None) shape_U = tensor_shape.TensorShape([None])
shape_1xU = tensor_shape.matrix(1, None) shape_1xU = tensor_shape.TensorShape([1, None])
shape_Ux1 = tensor_shape.matrix(None, 1) shape_Ux1 = tensor_shape.TensorShape([None, 1])
shape_4xU = tensor_shape.matrix(4, None) shape_4xU = tensor_shape.TensorShape([4, None])
shape_Ux4 = tensor_shape.matrix(None, 4) shape_Ux4 = tensor_shape.TensorShape([None, 4])
# pylint: enable=invalid-name # pylint: enable=invalid-name
# Tensors with same shape should have the same broadcast result. # Tensors with same shape should have the same broadcast result.
@ -183,7 +183,7 @@ class CommonShapesTest(test_util.TensorFlowTestCase):
self._assert_broadcast_with_unknown_dims( self._assert_broadcast_with_unknown_dims(
expected=shape_1xU, shape1=shape_U, shape2=shape_1xU) expected=shape_1xU, shape1=shape_U, shape2=shape_1xU)
shape_UxU = tensor_shape.matrix(None, None) # pylint: disable=invalid-name shape_UxU = tensor_shape.TensorShape([None, None]) # pylint: disable=invalid-name
self._assert_broadcast_with_unknown_dims( self._assert_broadcast_with_unknown_dims(
expected=shape_UxU, shape1=shape_U, shape2=shape_Ux1) expected=shape_UxU, shape1=shape_U, shape2=shape_Ux1)
self._assert_broadcast_with_unknown_dims( self._assert_broadcast_with_unknown_dims(
@ -200,7 +200,7 @@ class CommonShapesTest(test_util.TensorFlowTestCase):
expected=shape_4xU, shape1=shape_Ux1, shape2=shape_4xU) expected=shape_4xU, shape1=shape_Ux1, shape2=shape_4xU)
self._assert_broadcast_with_unknown_dims( self._assert_broadcast_with_unknown_dims(
expected=shape_Ux4, shape1=shape_Ux1, shape2=shape_Ux4) expected=shape_Ux4, shape1=shape_Ux1, shape2=shape_Ux4)
shape_4x4 = tensor_shape.matrix(4, 4) shape_4x4 = tensor_shape.TensorShape([4, 4])
self._assert_broadcast_with_unknown_dims( self._assert_broadcast_with_unknown_dims(
expected=shape_4x4, shape1=shape_4xU, shape2=shape_Ux4) expected=shape_4x4, shape1=shape_4xU, shape2=shape_Ux4)

View File

@ -75,15 +75,18 @@ class FunctionDefToGraphTest(test.TestCase):
self.assertIsNone(g.outputs[1].shape.dims) # Unknown dims. self.assertIsNone(g.outputs[1].shape.dims) # Unknown dims.
g = function_def_to_graph.function_def_to_graph( g = function_def_to_graph.function_def_to_graph(
fdef, input_shapes=[tensor_shape.vector(5), fdef,
tensor_shape.vector(5)]) input_shapes=[
tensor_shape.TensorShape([5]),
tensor_shape.TensorShape([5])
])
self.assertSequenceEqual(g.inputs[0].shape.dims, [5]) self.assertSequenceEqual(g.inputs[0].shape.dims, [5])
self.assertSequenceEqual(g.inputs[1].shape.dims, [5]) self.assertSequenceEqual(g.inputs[1].shape.dims, [5])
self.assertSequenceEqual(g.outputs[0].shape.dims, [5]) self.assertSequenceEqual(g.outputs[0].shape.dims, [5])
self.assertSequenceEqual(g.outputs[1].shape.dims, [5]) self.assertSequenceEqual(g.outputs[1].shape.dims, [5])
g = function_def_to_graph.function_def_to_graph( g = function_def_to_graph.function_def_to_graph(
fdef, input_shapes=[None, tensor_shape.matrix(5, 7)]) fdef, input_shapes=[None, tensor_shape.TensorShape([5, 7])])
self.assertIsNone(g.inputs[0].shape.dims) self.assertIsNone(g.inputs[0].shape.dims)
self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7]) self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7])
self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7]) self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7])
@ -93,7 +96,7 @@ class FunctionDefToGraphTest(test.TestCase):
# the number of input args in FunctionDef.signature.input_arg. # the number of input args in FunctionDef.signature.input_arg.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
g = function_def_to_graph.function_def_to_graph( g = function_def_to_graph.function_def_to_graph(
fdef, input_shapes=[tensor_shape.matrix(5, 7)]) fdef, input_shapes=[tensor_shape.TensorShape([5, 7])])
class FunctionDefToGraphDefTest(test.TestCase): class FunctionDefToGraphDefTest(test.TestCase):
@ -177,8 +180,10 @@ class FunctionDefToGraphDefTest(test.TestCase):
fdef = self._build_function_def() fdef = self._build_function_def()
g, _ = function_def_to_graph.function_def_to_graph_def( g, _ = function_def_to_graph.function_def_to_graph_def(
fdef, fdef,
input_shapes=[tensor_shape.scalar(), input_shapes=[
tensor_shape.vector(5), None]) tensor_shape.TensorShape([]),
tensor_shape.TensorShape([5]), None
])
self.assertEqual("shape" in g.node[0].attr, True) self.assertEqual("shape" in g.node[0].attr, True)
self.assertSequenceEqual( self.assertSequenceEqual(
tensor_shape.TensorShape(g.node[0].attr["shape"].shape).as_list(), []) tensor_shape.TensorShape(g.node[0].attr["shape"].shape).as_list(), [])

View File

@ -136,7 +136,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
a = array_ops.placeholder(dtype=dtypes.float32, shape=[]) a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
b = array_ops.ones([]) b = array_ops.ones([])
c = a + b c = a + b
self.assertEqual(tensor_shape.scalar(), c.shape) self.assertEqual(tensor_shape.TensorShape([]), c.shape)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testShapeFunctionError(self): def testShapeFunctionError(self):
@ -783,7 +783,7 @@ class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
self.assertEqual(op.name, "myop") self.assertEqual(op.name, "myop")
self.assertEqual(op.type, "Identity") self.assertEqual(op.type, "Identity")
self.assertEqual(len(op.outputs), 1) self.assertEqual(len(op.outputs), 1)
self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3)) self.assertEqual(op.outputs[0].shape, tensor_shape.TensorShape([2, 3]))
def testUniqueName(self): def testUniqueName(self):
g = ops.Graph() g = ops.Graph()

View File

@ -22,6 +22,7 @@ from tensorflow.python import tf2
from tensorflow.python.eager import monitoring from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
_TENSORSHAPE_V2_OVERRIDE = None _TENSORSHAPE_V2_OVERRIDE = None
@ -1238,11 +1239,13 @@ def unknown_shape(rank=None, **kwargs):
return TensorShape([Dimension(None)] * rank) return TensorShape([Dimension(None)] * rank)
@deprecation.deprecated(None, "Use tf.TensorShape([]).")
def scalar(): def scalar():
"""Returns a shape representing a scalar.""" """Returns a shape representing a scalar."""
return TensorShape([]) return TensorShape([])
@deprecation.deprecated(None, "Use tf.TensorShape([length]).")
def vector(length): def vector(length):
"""Returns a shape representing a vector. """Returns a shape representing a vector.
@ -1255,6 +1258,7 @@ def vector(length):
return TensorShape([length]) return TensorShape([length])
@deprecation.deprecated(None, "Use tf.TensorShape([rows, cols]).")
def matrix(rows, cols): def matrix(rows, cols):
"""Returns a shape representing a matrix. """Returns a shape representing a matrix.

View File

@ -377,14 +377,6 @@ class ShapeTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self._testMostSpecificCompatibleShapeHelper([1, 1, 3], [None, 2, 3], self._testMostSpecificCompatibleShapeHelper([1, 1, 3], [None, 2, 3],
[None, None, 3]) [None, None, 3])
def testHelpers(self):
tensor_shape.TensorShape([]).assert_is_compatible_with(
tensor_shape.scalar())
tensor_shape.TensorShape([37]).assert_is_compatible_with(
tensor_shape.vector(37))
tensor_shape.TensorShape(
[94, 43]).assert_is_compatible_with(tensor_shape.matrix(94, 43))
def testTruedivFails(self): def testTruedivFails(self):
unknown = tensor_shape.Dimension(None) unknown = tensor_shape.Dimension(None)
self.assertEqual((unknown // unknown).value, None) self.assertEqual((unknown // unknown).value, None)
@ -430,9 +422,9 @@ class ShapeTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertEqual( self.assertEqual(
"(32, None, 1, 9)", "(32, None, 1, 9)",
str(tensor_shape.TensorShape([32, None, 1, 9])).replace("?", "None")) str(tensor_shape.TensorShape([32, None, 1, 9])).replace("?", "None"))
self.assertEqual("()", str(tensor_shape.scalar())) self.assertEqual("()", str(tensor_shape.TensorShape([])))
self.assertEqual("(7,)", str(tensor_shape.vector(7))) self.assertEqual("(7,)", str(tensor_shape.TensorShape([7])))
self.assertEqual("(3, 8)", str(tensor_shape.matrix(3, 8))) self.assertEqual("(3, 8)", str(tensor_shape.TensorShape([3, 8])))
self.assertEqual("(4, 5, 2)", str(tensor_shape.TensorShape([4, 5, 2]))) self.assertEqual("(4, 5, 2)", str(tensor_shape.TensorShape([4, 5, 2])))
def testAsProto(self): def testAsProto(self):

View File

@ -833,11 +833,11 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
shape = tensor.get_shape().with_rank(1) shape = tensor.get_shape().with_rank(1)
if shape == [0]: if shape == [0]:
return tensor_shape.scalar() return tensor_shape.TensorShape([])
elif tensor.op.type == "Shape": elif tensor.op.type == "Shape":
return tensor.op.inputs[0].get_shape() return tensor.op.inputs[0].get_shape()
elif tensor.op.type == "Pack": elif tensor.op.type == "Pack":
ret = tensor_shape.scalar() # Empty list. ret = tensor_shape.TensorShape([]) # Empty list.
# Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it
# would not be rank 1. # would not be rank 1.
assert tensor.op.get_attr("axis") == 0 assert tensor.op.get_attr("axis") == 0
@ -855,7 +855,7 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
# We assume that `tensor.op.inputs[0]` evaluates to 0, as this is # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is
# the only legal value when concatenating vectors, and it will # the only legal value when concatenating vectors, and it will
# have been checked by a previous shape function. # have been checked by a previous shape function.
ret = tensor_shape.scalar() # Empty list. ret = tensor_shape.TensorShape([]) # Empty list.
for concat_input in tensor.op.inputs[1:]: for concat_input in tensor.op.inputs[1:]:
# `concat_input` must be a vector. Attempt to evaluate it as a shape, # `concat_input` must be a vector. Attempt to evaluate it as a shape,
# and concatenate it with `ret`. # and concatenate it with `ret`.
@ -865,7 +865,7 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
# We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is
# the only legal value when concatenating vectors, and it will # the only legal value when concatenating vectors, and it will
# have been checked by a previous shape function. # have been checked by a previous shape function.
ret = tensor_shape.scalar() # Empty list. ret = tensor_shape.TensorShape([]) # Empty list.
for concat_input in tensor.op.inputs[:-1]: for concat_input in tensor.op.inputs[:-1]:
# `concat_input` must be a vector. Attempt to evaluate it as a shape, # `concat_input` must be a vector. Attempt to evaluate it as a shape,
# and concatenate it with `ret`. # and concatenate it with `ret`.
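
A minimal sketch of the base case above (TF 1.x graph mode, internal APIs): a constant empty vector, with static shape [0], is read back as the shape of a scalar.

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util

with ops.Graph().as_default():
  empty = constant_op.constant([], dtype=dtypes.int32)  # static shape [0]
  shape = tensor_util.constant_value_as_shape(empty)
  assert shape == tensor_shape.TensorShape([])  # the empty-list base case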

View File

@ -129,8 +129,9 @@ class GrapplerTest(test.TestCase):
mg = meta_graph.create_meta_graph_def(graph=g) mg = meta_graph.create_meta_graph_def(graph=g)
grappler_item = item.Item(mg) grappler_item = item.Item(mg)
op_properties = grappler_item.GetOpProperties() op_properties = grappler_item.GetOpProperties()
self.assertEqual(tensor_shape.scalar(), self.assertEqual(
op_properties['IteratorGetNext'][0].shape) tensor_shape.TensorShape([]),
op_properties['IteratorGetNext'][0].shape)
def _testTransformation(self, fn): def _testTransformation(self, fn):
test_cases = [{ test_cases = [{

View File

@ -80,7 +80,7 @@ class ItemTest(test.TestCase):
else: else:
self.assertEqual(1, len(node_prop)) self.assertEqual(1, len(node_prop))
self.assertEqual(dtypes.int32, node_prop[0].dtype) self.assertEqual(dtypes.int32, node_prop[0].dtype)
self.assertEqual(tensor_shape.scalar(), node_prop[0].shape) self.assertEqual(tensor_shape.TensorShape([]), node_prop[0].shape)
def testUpdates(self): def testUpdates(self):
with ops.Graph().as_default() as g: with ops.Graph().as_default() as g:

View File

@ -391,7 +391,7 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
b = control_flow_ops.cond( b = control_flow_ops.cond(
constant_op.constant(True), lambda: math_ops.square(x), constant_op.constant(True), lambda: math_ops.square(x),
lambda: math_ops.subtract(x, 1.)) lambda: math_ops.subtract(x, 1.))
self.assertEqual(b.shape, tensor_shape.scalar()) self.assertEqual(b.shape, tensor_shape.TensorShape([]))
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("b/120545219")
def testFetchable(self): def testFetchable(self):

View File

@ -1166,10 +1166,10 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertEqual(fn(tensor_shape.unknown_shape()), -1) self.assertEqual(fn(tensor_shape.unknown_shape()), -1)
# Scalar shape -> [] with type int32. # Scalar shape -> [] with type int32.
self.assertEqual(fn([]).dtype, dtypes.int32) self.assertEqual(fn([]).dtype, dtypes.int32)
self.assertEqual(fn(tensor_shape.scalar()).dtype, dtypes.int32) self.assertEqual(fn(tensor_shape.TensorShape([])).dtype, dtypes.int32)
self.assertAllEqual(self.evaluate(fn([])), np.array([], np.int32)) self.assertAllEqual(self.evaluate(fn([])), np.array([], np.int32))
self.assertAllEqual( self.assertAllEqual(
self.evaluate(fn(tensor_shape.scalar())), np.array([], np.int32)) self.evaluate(fn(tensor_shape.TensorShape([]))), np.array([], np.int32))
# Tensor -> Tensor # Tensor -> Tensor
shape = constant_op.constant(1) shape = constant_op.constant(1)
self.assertIs(fn(shape), shape) self.assertIs(fn(shape), shape)
@ -1327,7 +1327,8 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def testConcatListWithScalarElementShapeFails(self): def testConcatListWithScalarElementShapeFails(self):
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_dtype=dtypes.float32, element_shape=tensor_shape.scalar()) element_dtype=dtypes.float32,
element_shape=tensor_shape.TensorShape([]))
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
errors.InvalidArgumentError, errors.InvalidArgumentError,
"Concat requires elements to be at least vectors, " "Concat requires elements to be at least vectors, "

View File

@ -1034,7 +1034,7 @@ class TensorArrayTest(test.TestCase):
dtype=dtypes.float32, dtype=dtypes.float32,
size=num_steps, size=num_steps,
clear_after_read=False, clear_after_read=False,
element_shape=tensor_shape.scalar()) element_shape=tensor_shape.TensorShape([]))
i = constant_op.constant(0, name="i") i = constant_op.constant(0, name="i")
c = lambda i, acc: i < 5 c = lambda i, acc: i < 5
@ -1693,10 +1693,10 @@ class TensorArrayTest(test.TestCase):
self.assertEqual(dtypes.float32, ta0.dtype) self.assertEqual(dtypes.float32, ta0.dtype)
self.assertEqual(dtypes.int32, ta1.dtype) self.assertEqual(dtypes.int32, ta1.dtype)
if context.executing_eagerly(): if context.executing_eagerly():
self.assertEqual(tensor_shape.scalar(), read0.get_shape()) self.assertEqual(tensor_shape.TensorShape([]), read0.get_shape())
else: else:
self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape()) self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
self.assertEqual(tensor_shape.scalar(), read1.get_shape()) self.assertEqual(tensor_shape.TensorShape([]), read1.get_shape())
if not context.executing_eagerly(): if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer()) self.evaluate(variables.global_variables_initializer())

View File

@ -60,7 +60,7 @@ class AccumulateNBenchmark(test.Benchmark):
return self._AccumulateNTemplate( return self._AccumulateNTemplate(
inputs, inputs,
init=array_ops.zeros_like(gen_control_flow_ops.merge(inputs)[0]), init=array_ops.zeros_like(gen_control_flow_ops.merge(inputs)[0]),
shape=tensor_shape.vector(0), shape=tensor_shape.TensorShape([0]),
validate_shape=False) validate_shape=False)
def _AccumulateNInitializedWithShape(self, inputs): def _AccumulateNInitializedWithShape(self, inputs):

View File

@ -1307,8 +1307,7 @@ def concat(values, axis, name="concat"):
with ops.name_scope(name) as scope: with ops.name_scope(name) as scope:
ops.convert_to_tensor( ops.convert_to_tensor(
axis, name="concat_dim", axis, name="concat_dim",
dtype=dtypes.int32).get_shape().assert_is_compatible_with( dtype=dtypes.int32).get_shape().assert_has_rank(0)
tensor_shape.scalar())
return identity(values[0], name=scope) return identity(values[0], name=scope)
return gen_array_ops.concat_v2(values=values, axis=axis, name=name) return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
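
The replacement is behavior-preserving: asserting rank 0 accepts exactly the static shapes that are compatible with the scalar shape. A small sketch:

# Both assertions pass for a known scalar shape and for a fully unknown one,
# and both raise ValueError for any shape of known nonzero rank.
from tensorflow.python.framework import tensor_shape

for s in (tensor_shape.TensorShape([]), tensor_shape.unknown_shape()):
  s.assert_has_rank(0)
  s.assert_is_compatible_with(tensor_shape.TensorShape([]))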

View File

@ -1092,8 +1092,8 @@ class Barrier(object):
else: else:
batch_dim = tensor_shape.Dimension( batch_dim = tensor_shape.Dimension(
tensor_util.constant_value(op.inputs[1])) tensor_util.constant_value(op.inputs[1]))
op.outputs[0].set_shape(tensor_shape.vector(batch_dim)) # indices op.outputs[0].set_shape(tensor_shape.TensorShape([batch_dim])) # indices
op.outputs[1].set_shape(tensor_shape.vector(batch_dim)) # keys op.outputs[1].set_shape(tensor_shape.TensorShape([batch_dim])) # keys
for output, shape in zip(op.outputs[2:], self._shapes): # value_list for output, shape in zip(op.outputs[2:], self._shapes): # value_list
output.set_shape( output.set_shape(
tensor_shape.TensorShape([batch_dim]).concatenate(shape)) tensor_shape.TensorShape([batch_dim]).concatenate(shape))
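
A minimal sketch of the shape arithmetic used for the value_list outputs above: prepend a possibly unknown batch dimension to each per-element shape.

from tensorflow.python.framework import tensor_shape

batch_dim = tensor_shape.Dimension(None)        # batch size not statically known
value_shape = tensor_shape.TensorShape([2, 2])  # hypothetical element shape
full = tensor_shape.TensorShape([batch_dim]).concatenate(value_shape)
assert full.as_list() == [None, 2, 2]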

View File

@ -120,7 +120,7 @@ class Bernoulli(distribution.Distribution):
return array_ops.constant([], dtype=dtypes.int32) return array_ops.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
new_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) new_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
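
The same one-line substitution repeats across the univariate distributions below (Beta, Categorical, Gamma, Laplace, Normal, StudentT, Uniform); a trivial sketch of what the scalar event shape expresses:

# A univariate distribution has no event dimensions: samples have shape
# [n] + batch_shape with nothing appended for the event.
from tensorflow.python.framework import tensor_shape

event_shape = tensor_shape.TensorShape([])
assert event_shape.ndims == 0 and event_shape.as_list() == []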

View File

@ -238,7 +238,7 @@ class Beta(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
expanded_concentration1 = array_ops.ones_like( expanded_concentration1 = array_ops.ones_like(

View File

@ -266,7 +266,7 @@ class Categorical(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
if self.logits.get_shape().ndims == 2: if self.logits.get_shape().ndims == 2:

View File

@ -210,7 +210,7 @@ class Gamma(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
@distribution_util.AppendDocstring( @distribution_util.AppendDocstring(
"""Note: See `tf.random.gamma` docstring for sampling details and """Note: See `tf.random.gamma` docstring for sampling details and

View File

@ -153,7 +153,7 @@ class Laplace(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)

View File

@ -189,7 +189,7 @@ class Normal(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)

View File

@ -241,7 +241,7 @@ class StudentT(distribution.Distribution):
return constant_op.constant([], dtype=math_ops.int32) return constant_op.constant([], dtype=math_ops.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
# The sampling method comes from the fact that if: # The sampling method comes from the fact that if:

View File

@ -165,7 +165,7 @@ class Uniform(distribution.Distribution):
return constant_op.constant([], dtype=dtypes.int32) return constant_op.constant([], dtype=dtypes.int32)
def _event_shape(self): def _event_shape(self):
return tensor_shape.scalar() return tensor_shape.TensorShape([])
def _sample_n(self, n, seed=None): def _sample_n(self, n, seed=None):
shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)

View File

@ -166,7 +166,7 @@ class InitializableLookupTableBase(LookupInterface):
initializer.value_dtype) initializer.value_dtype)
self._default_value = ops.convert_to_tensor( self._default_value = ops.convert_to_tensor(
default_value, dtype=self._value_dtype) default_value, dtype=self._value_dtype)
self._default_value.get_shape().merge_with(tensor_shape.scalar()) self._default_value.get_shape().merge_with(tensor_shape.TensorShape([]))
if isinstance(initializer, trackable_base.Trackable): if isinstance(initializer, trackable_base.Trackable):
self._initializer = self._track_trackable(initializer, "_initializer") self._initializer = self._track_trackable(initializer, "_initializer")
with ops.init_scope(): with ops.init_scope():

View File

@ -2282,7 +2282,8 @@ def atrous_conv2d_transpose(value,
data_format="NHWC") data_format="NHWC")
output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)): if not output_shape_.get_shape().is_compatible_with(
tensor_shape.TensorShape([4])):
raise ValueError("output_shape must have shape (4,), got {}".format( raise ValueError("output_shape must have shape (4,), got {}".format(
output_shape_.get_shape())) output_shape_.get_shape()))
@ -4233,7 +4234,7 @@ def dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
else: else:
rate = ops.convert_to_tensor( rate = ops.convert_to_tensor(
rate, dtype=x.dtype, name="rate") rate, dtype=x.dtype, name="rate")
rate.get_shape().assert_is_compatible_with(tensor_shape.scalar()) rate.get_shape().assert_has_rank(0)
# Do nothing if we know rate == 0 # Do nothing if we know rate == 0
if tensor_util.constant_value(rate) == 0: if tensor_util.constant_value(rate) == 0:

View File

@ -1338,8 +1338,8 @@ class TensorArraySpec(type_spec.TypeSpec):
def _to_legacy_output_shapes(self): def _to_legacy_output_shapes(self):
# Sneak the dynamic_size and infer_shape values into the legacy shape. # Sneak the dynamic_size and infer_shape values into the legacy shape.
return (tensor_shape.matrix(self._dynamic_size, self._infer_shape) return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape
.concatenate(self._element_shape)) ]).concatenate(self._element_shape))
def _to_legacy_output_classes(self): def _to_legacy_output_classes(self):
return TensorArray return TensorArray
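
A hedged sketch of the "sneak" encoding above: the two booleans become the leading dimensions of the legacy shape (False packs as 0, True as 1, since Dimension coerces its value with int()), so they survive the round trip through legacy output shapes.

from tensorflow.python.framework import tensor_shape

dynamic_size, infer_shape = True, False           # hypothetical spec fields
element_shape = tensor_shape.TensorShape([None, 3])
legacy = tensor_shape.TensorShape(
    [dynamic_size, infer_shape]).concatenate(element_shape)
assert legacy.as_list() == [1, 0, None, 3]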

View File

@ -107,8 +107,7 @@ def while_loop(cond,
# Add loop counter needed for computing gradients. # Add loop counter needed for computing gradients.
loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars
shape_invariants = ( shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
[tensor_shape.scalar(), tensor_shape.scalar()] + shape_invariants)
signature = ( signature = (
[tensor_spec.TensorSpec.from_tensor(loop_counter), [tensor_spec.TensorSpec.from_tensor(loop_counter),
tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] + tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +