[tf.debugging] Make assert_shapes() work on SparseTensors
Fixes https://github.com/tensorflow/tensorflow/issues/36268 PiperOrigin-RevId: 334658746 Change-Id: I53351bf9e8223a37e7cc447651be546fc87cecb6
This commit is contained in:
parent
d3a378f966
commit
11fc1489d0
20
RELEASE.md
20
RELEASE.md
@ -289,6 +289,10 @@
|
||||
|
||||
* `tf.nn.max_pool2d` now supports explicit padding.
|
||||
|
||||
* `tf.debugging`:
|
||||
|
||||
* `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
|
||||
|
||||
* Other:
|
||||
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
@ -1192,7 +1196,7 @@ This release contains contributions from many people at Google, as well as:
|
||||
8bitmp3, Aaron Ma, AbdüLhamit Yilmaz, Abhai Kollara, aflc, Ag Ramesh, Albert Z. Guo, Alex Torres, amoitra, Andrii Prymostka, angeliand, Anshuman Tripathy, Anthony Barbier, Anton Kachatkou, Anubh-V, Anuja Jakhade, Artem Ryabov, autoih, Bairen Yi, Bas Aarts, Basit Ayantunde, Ben Barsdell, Bhavani Subramanian, Brett Koonce, candy.dc, Captain-Pool, caster, cathy, Chong Yan, Choong Yin Thong, Clayne Robison, Colle, Dan Ganea, David Norman, David Refaeli, dengziming, Diego Caballero, Divyanshu, djshen, Douman, Duncan Riach, EFanZh, Elena Zhelezina, Eric Schweitz, Evgenii Zheltonozhskii, Fei Hu, fo40225, Fred Reiss, Frederic Bastien, Fredrik Knutsson, fsx950223, fwcore, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, giuros01, Gomathi Ramamurthy, Guozhong Zhuang, Haifeng Jin, Haoyu Wu, HarikrishnanBalagopal, HJYOO, Huang Chen-Yi, Ilham Firdausi Putra, Imran Salam, Jared Nielsen, Jason Zaman, Jasper Vicenti, Jeff Daily, Jeff Poznanovic, Jens Elofsson, Jerry Shih, jerryyin, Jesper Dramsch, jim.meyer, Jongwon Lee, Jun Wan, Junyuan Xie, Kaixi Hou, kamalkraj, Kan Chen, Karthik Muthuraman, Keiji Ariyama, Kevin Rose, Kevin Wang, Koan-Sin Tan, kstuedem, Kwabena W. Agyeman, Lakshay Tokas, latyas, Leslie-Fang-Intel, Li, Guizi, Luciano Resende, Lukas Folle, Lukas Geiger, Mahmoud Abuzaina, Manuel Freiberger, Mark Ryan, Martin Mlostek, Masaki Kozuki, Matthew Bentham, Matthew Denton, mbhuiyan, mdfaijul, Muhwan Kim, Nagy Mostafa, nammbash, Nathan Luehr, Nathan Wells, Niranjan Hasabnis, Oleksii Volkovskyi, Olivier Moindrot, olramde, Ouyang Jin, OverLordGoldDragon, Pallavi G, Paul Andrey, Paul Wais, pkanwar23, Pooya Davoodi, Prabindh Sundareson, Rajeshwar Reddy T, Ralovich, Kristof, Refraction-Ray, Richard Barnes, richardbrks, Robert Herbig, Romeo Kienzler, Ryan Mccormick, saishruthi, Saket Khandelwal, Sami Kama, Sana Damani, Satoshi Tanaka, Sergey Mironov, Sergii Khomenko, Shahid, Shawn Presser, ShengYang1, Siddhartha Bagaria, Simon Plovyt, skeydan, srinivasan.narayanamoorthy, Stephen Mugisha, sunway513, Takeshi Watanabe, Taylor Jakobson, TengLu, TheMindVirus, ThisIsIsaac, Tim Gates, Timothy Liu, Tomer Gafner, Trent Lo, Trevor Hickey, Trevor Morris, vcarpani, Wei Wang, Wen-Heng (Jack) Chung, wenshuai, Wenshuai-Xiaomi, wenxizhu, william, William D. Irons, Xinan Jiang, Yannic, Yasir Modak, Yasuhiro Matsumoto, Yong Tang, Yongfeng Gu, Youwei Song, Zaccharie Ramzi, Zhang, Zhenyu Guo, 王振华 (Zhenhua Wang), 韩董, 이중건 Isaac Lee
|
||||
|
||||
# Release 1.15.0
|
||||
This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year.
|
||||
This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year.
|
||||
|
||||
## Major Features and Improvements
|
||||
* As [announced](https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0), `tensorflow` pip package will by default include GPU support (same as `tensorflow-gpu` now) for the platforms we currently have GPU support (Linux and Windows). It will work on machines with and without Nvidia GPUs. `tensorflow-gpu` will still be available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size.
|
||||
@ -1202,7 +1206,7 @@ This enables writing forward compatible code: by explicitly importing either `te
|
||||
* Add toggles `tf.enable_control_flow_v2()` and `tf.disable_control_flow_v2()` for enabling/disabling v2 control flow.
|
||||
* Enable v2 control flow as part of `tf.enable_v2_behavior()` and `TF2_BEHAVIOR=1`.
|
||||
* AutoGraph translates Python control flow into TensorFlow expressions, allowing users to write regular Python inside `tf.function`-decorated functions. AutoGraph is also applied in functions used with `tf.data`, `tf.distribute` and `tf.keras` APIS.
|
||||
* Adds `enable_tensor_equality()`, which switches the behavior such that:
|
||||
* Adds `enable_tensor_equality()`, which switches the behavior such that:
|
||||
* Tensors are no longer hashable.
|
||||
* Tensors can be compared with `==` and `!=`, yielding a Boolean Tensor with element-wise comparison results. This will be the default behavior in 2.0.
|
||||
|
||||
@ -1358,12 +1362,12 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
|
||||
* TensorFlow 2.0.0 is built using devtoolset7 (GCC7) on Ubuntu 16. This may lead to ABI incompatibilities with extensions built against earlier versions of TensorFlow.
|
||||
* Tensorflow code now produces 2 different pip packages: tensorflow_core containing all the code (in the future it will contain only the private implementation) and tensorflow which is a virtual pip package doing forwarding to tensorflow_core (and in the future will contain only the public API of tensorflow). We don't expect this to be breaking, unless you were importing directly from the implementation.
|
||||
Removed the `freeze_graph` command line tool; `SavedModel` should be used in place of frozen graphs.
|
||||
|
||||
|
||||
* `tf.contrib`:
|
||||
* `tf.contrib` has been deprecated, and functionality has been either migrated to the core TensorFlow API, to an ecosystem project such as [tensorflow/addons](https://www.github.com/tensorflow/addons) or [tensorflow/io](https://www.github.com/tensorflow/io), or removed entirely.
|
||||
* Remove `tf.contrib.timeseries` dependency on TF distributions.
|
||||
* Replace contrib references with `tf.estimator.experimental.*` for apis in `early_stopping.py`.
|
||||
|
||||
|
||||
* `tf.estimator`:
|
||||
* Premade estimators in the tf.estimator.DNN/Linear/DNNLinearCombined family have been updated to use `tf.keras.optimizers` instead of the `tf.compat.v1.train.Optimizer`s. If you do not pass in an `optimizer=` arg or if you use a string, the premade estimator will use the Keras optimizer. This is checkpoint breaking, as the optimizers have separate variables. A checkpoint converter tool for converting optimizers is included with the release, but if you want to avoid any change, switch to the v1 version of the estimator: `tf.compat.v1.estimator.DNN/Linear/DNNLinearCombined*`.
|
||||
* Default aggregation for canned Estimators is now `SUM_OVER_BATCH_SIZE`. To maintain previous default behavior, please pass `SUM` as the loss aggregation method.
|
||||
@ -1371,13 +1375,13 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
|
||||
* `Estimator.export_savedmodel` has been renamed to `export_saved_model`.
|
||||
* When saving to SavedModel, Estimators will strip default op attributes. This is almost always the correct behavior, as it is more forwards compatible, but if you require that default attributes to be saved with the model, please use `tf.compat.v1.Estimator`.
|
||||
* Feature Columns have been upgraded to be more Eager-friendly and to work with Keras. As a result, `tf.feature_column.input_layer` has been deprecated in favor of `tf.keras.layers.DenseFeatures`. v1 feature columns have direct analogues in v2 except for `shared_embedding_columns`, which are not cross-compatible with v1 and v2. Use `tf.feature_column.shared_embeddings` instead.
|
||||
|
||||
|
||||
* `tf.keras`:
|
||||
* `OMP_NUM_THREADS` is no longer used by the default Keras config. To configure the number of threads, use `tf.config.threading` APIs.
|
||||
* `tf.keras.model.save_model` and `model.save` now defaults to saving a TensorFlow SavedModel. HDF5 files are still supported.
|
||||
* Deprecated `tf.keras.experimental.export_saved_model` and `tf.keras.experimental.function`. Please use `tf.keras.models.save_model(..., save_format='tf')` and `tf.keras.models.load_model` instead.
|
||||
* Layers now default to float32, and automatically cast their inputs to the layer's dtype. If you had a model that used float64, it will probably silently use float32 in TensorFlow 2, and a warning will be issued that starts with `Layer <layer-name>` is casting an input tensor from dtype float64 to the layer's dtype of float32. To fix, either set the default dtype to float64 with `tf.keras.backend.set_floatx('float64')`, or pass `dtype='float64'` to each of the Layer constructors. See `tf.keras.layers.Layer` for more information.
|
||||
|
||||
|
||||
* `tf.lite`:
|
||||
* Removed `lite.OpHint`, `lite.experimental`, and `lite.constant` from 2.0 API.
|
||||
* Tensors are no longer hashable, but instead compare element-wise with `==` and `!=`. Use `tf.compat.v1.disable_tensor_equality()` to return to the previous behavior.
|
||||
@ -2623,7 +2627,7 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A
|
||||
* [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier)
|
||||
* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector)
|
||||
API supports broadcasting for Bijectors with new API changes.
|
||||
|
||||
|
||||
## Breaking Changes
|
||||
* If you're opening empty variable scopes; replace `variable_scope('', ...)` by
|
||||
`variable_scope(tf.get_variable_scope(), ...)`.
|
||||
@ -3102,7 +3106,7 @@ Samuel He, Sandeep Dcunha, sandipmgiri, Sang Han, scott, Scott Mudge, Se-Won Kim
|
||||
Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, Taehoon Lee, Ted Chang, Thomas Deegan,
|
||||
Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay,
|
||||
Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang,
|
||||
Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武
|
||||
Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武
|
||||
|
||||
We are also grateful to all who filed issues or helped resolve them, asked and
|
||||
answered questions, and were part of inspiring discussions.
|
||||
|
@ -1903,6 +1903,174 @@ class AssertShapesTest(test.TestCase):
|
||||
sess.run(out, feed_dict=feed_dict)
|
||||
|
||||
|
||||
class AssertShapesSparseTensorTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_scalar_target_success(self):
|
||||
sparse_float = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[]], dtypes.int64),
|
||||
constant_op.constant([42], dtypes.float32),
|
||||
constant_op.constant([], dtypes.int64))
|
||||
assertion = check_ops.assert_shapes([(sparse_float, [])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_float)
|
||||
self.evaluate(out)
|
||||
|
||||
def test_assert_shapes_sparse_tensor_nonscalar_target_fail(self):
|
||||
sparse_float = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[]], dtypes.int64),
|
||||
constant_op.constant([42], dtypes.float32),
|
||||
constant_op.constant([], dtypes.int64))
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r"must have rank 2.*Received rank 0"):
|
||||
assertion = check_ops.assert_shapes([(sparse_float, [None, None])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_float)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_fully_specified_target_success(self):
|
||||
sparse_float = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[111], [232]], dtypes.int64),
|
||||
constant_op.constant([23.4, -43.2], dtypes.float32),
|
||||
constant_op.constant([500], dtypes.int64))
|
||||
assertion = check_ops.assert_shapes([(sparse_float, [500])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_float)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_fully_specified_target_fail(self):
|
||||
sparse_float = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[111], [232]], dtypes.int64),
|
||||
constant_op.constant([23.4, -43.2], dtypes.float32),
|
||||
constant_op.constant([500], dtypes.int64))
|
||||
with self.assertRaisesRegexp(ValueError, r"dimension 0 must have size 499"):
|
||||
assertion = check_ops.assert_shapes([(sparse_float, [499])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_float)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_partially_specified_target_success(self):
|
||||
sparse_int = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
|
||||
constant_op.constant([23, -43], dtypes.int32),
|
||||
constant_op.constant([30, 40], dtypes.int64))
|
||||
assertion = check_ops.assert_shapes([(sparse_int, [None, 40])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_int)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_symbolic_match_success(self):
|
||||
sparse_int = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[5, 6, 7], [8, 9, 10]], dtypes.int64),
|
||||
constant_op.constant([23, -43], dtypes.int32),
|
||||
constant_op.constant([30, 30, 40], dtypes.int64))
|
||||
assertion = check_ops.assert_shapes([(sparse_int, ["N", "N", "D"])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_int)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_partially_specified_target_fail(self):
|
||||
sparse_int = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
|
||||
constant_op.constant([23, -43], dtypes.int32),
|
||||
constant_op.constant([30, 40], dtypes.int64))
|
||||
with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 41"):
|
||||
assertion = check_ops.assert_shapes([(sparse_int, [None, 41])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_int)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_wrong_rank_fail(self):
|
||||
sparse_int = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
|
||||
constant_op.constant([23, -43], dtypes.int32),
|
||||
constant_op.constant([30, 40], dtypes.int64))
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r"must have rank 3\..* Received rank 2"):
|
||||
assertion = check_ops.assert_shapes([(sparse_int, [None, None, 40])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_int)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_wrong_symbolic_match_fail(self):
|
||||
sparse_int = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
|
||||
constant_op.constant([23, -43], dtypes.int32),
|
||||
constant_op.constant([30, 40], dtypes.int64))
|
||||
with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 30"):
|
||||
assertion = check_ops.assert_shapes([(sparse_int, ["D", "D"])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_int)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_multiple_assertions_success(self):
|
||||
sparse_scalar = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[]], dtypes.int64),
|
||||
constant_op.constant([42], dtypes.float32),
|
||||
constant_op.constant([], dtypes.int64))
|
||||
sparse_2d = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
|
||||
constant_op.constant([23, -43], dtypes.int32),
|
||||
constant_op.constant([30, 30], dtypes.int64))
|
||||
assertion = check_ops.assert_shapes([(sparse_scalar, []),
|
||||
(sparse_2d, ["N", "N"])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_2d)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_multiple_assertions_fail(self):
|
||||
sparse_scalar = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[]], dtypes.int64),
|
||||
constant_op.constant([42], dtypes.float32),
|
||||
constant_op.constant([], dtypes.int64))
|
||||
sparse_2d = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
|
||||
constant_op.constant([23, -43], dtypes.int32),
|
||||
constant_op.constant([30, 40], dtypes.int64))
|
||||
with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 30"):
|
||||
assertion = check_ops.assert_shapes([(sparse_scalar, []),
|
||||
(sparse_2d, ["N", "N"])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_2d)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_success(self):
|
||||
dense_scalar = constant_op.constant([42], dtypes.float32)
|
||||
sparse_2d = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
|
||||
constant_op.constant([23, -43], dtypes.int32),
|
||||
constant_op.constant([30, 30], dtypes.int64))
|
||||
assertion = check_ops.assert_shapes([(dense_scalar, []),
|
||||
(sparse_2d, ["N", "N"])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_2d)
|
||||
self.evaluate(out)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_fail(self):
|
||||
dense_scalar = constant_op.constant([42], dtypes.float32)
|
||||
sparse_2d = sparse_tensor.SparseTensor(
|
||||
constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
|
||||
constant_op.constant([23, -43], dtypes.int32),
|
||||
constant_op.constant([30, 40], dtypes.int64))
|
||||
with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 30"):
|
||||
assertion = check_ops.assert_shapes([(dense_scalar, []),
|
||||
(sparse_2d, ["N", "N"])])
|
||||
with ops.control_dependencies([assertion]):
|
||||
out = array_ops.identity(sparse_2d)
|
||||
self.evaluate(out)
|
||||
|
||||
|
||||
class IsStrictlyIncreasingTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
|
@ -1151,14 +1151,15 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
|
||||
ValueError: If static checks determine `x` has wrong rank.
|
||||
"""
|
||||
with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
if not isinstance(x, sparse_tensor.SparseTensor):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
rank = ops.convert_to_tensor(rank, name='rank')
|
||||
message = message or ''
|
||||
|
||||
static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
|
||||
dynamic_condition = math_ops.equal
|
||||
|
||||
if context.executing_eagerly():
|
||||
if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
|
||||
name = ''
|
||||
else:
|
||||
name = x.name
|
||||
@ -1418,11 +1419,12 @@ def assert_rank_in(
|
||||
"""
|
||||
with ops.name_scope(
|
||||
name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
if not isinstance(x, sparse_tensor.SparseTensor):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
|
||||
message = message or ''
|
||||
|
||||
if context.executing_eagerly():
|
||||
if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
|
||||
name = ''
|
||||
else:
|
||||
name = x.name
|
||||
@ -1582,7 +1584,7 @@ def _dimension_sizes(x):
|
||||
rank = x.get_shape().rank
|
||||
rank_is_known = rank is not None
|
||||
if rank_is_known and rank == 0:
|
||||
return tuple([1])
|
||||
return (1,)
|
||||
if rank_is_known and rank > 0:
|
||||
static_shape = x.get_shape().as_list()
|
||||
sizes = [
|
||||
@ -1787,14 +1789,14 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
|
||||
message = message or ''
|
||||
with ops.name_scope(name, 'assert_shapes', [shapes, data]):
|
||||
# Shape specified as None implies no constraint
|
||||
shape_constraints = [
|
||||
(ops.convert_to_tensor(x), s) for x, s in shapes if s is not None
|
||||
]
|
||||
shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else
|
||||
ops.convert_to_tensor(x), s)
|
||||
for x, s in shapes if s is not None]
|
||||
|
||||
executing_eagerly = context.executing_eagerly()
|
||||
|
||||
def tensor_name(x):
|
||||
if executing_eagerly:
|
||||
if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor):
|
||||
return _shape_and_dtype_str(x)
|
||||
return x.name
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user