diff --git a/tensorflow/python/framework/indexed_slices.py b/tensorflow/python/framework/indexed_slices.py index 2063680c034..8bc21cac682 100644 --- a/tensorflow/python/framework/indexed_slices.py +++ b/tensorflow/python/framework/indexed_slices.py @@ -23,6 +23,7 @@ import collections import warnings import numpy as np +from tensorflow.python import tf2 from tensorflow.python.eager import context from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes @@ -235,7 +236,14 @@ class IndexedSlicesSpec(type_spec.TypeSpec): return (value.values, value.indices, value.dense_shape) def _from_components(self, tensor_list): - return IndexedSlices(*tensor_list) + if (all(isinstance(t, np.ndarray) for t in tensor_list) and + not tf2.enabled()): + if len(tensor_list) == 2: + return IndexedSlicesValue(tensor_list[0], tensor_list[1], None) + else: + return IndexedSlicesValue(*tensor_list) + else: + return IndexedSlices(*tensor_list) @tf_export(v1=["convert_to_tensor_or_indexed_slices"]) diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index e6520aecb53..d171f90af54 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import gc +import numpy as np import os import threading import weakref @@ -303,6 +304,26 @@ class IndexedSlicesSpecTest(test_util.TensorFlowTestCase, else: self.assertAllEqual(x.dense_shape, st_reconstructed.dense_shape) + @test_util.run_v1_only("IndexedSlicesValue is deprecated in v2") + def testFromNumpyComponents(self): + indices = np.array([3, 8]) + values = np.array([1.0, 9.0]) + dense_shape = np.array([100]) + + spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32) + st1 = spec1._from_components((values, indices, dense_shape)) + self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue) + self.assertAllEqual(st1.indices, indices) + self.assertAllEqual(st1.values, values) + self.assertAllEqual(st1.dense_shape, dense_shape) + + spec2 = indexed_slices.IndexedSlicesSpec() + st2 = spec2._from_components((values, indices)) + self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue) + self.assertAllEqual(st2.indices, indices) + self.assertAllEqual(st2.values, values) + self.assertIs(st2.dense_shape, None) + class NodeDefConstructorTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 788d0e97faf..587ea17dc40 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -22,6 +22,7 @@ import collections import numpy as np from tensorflow.python import pywrap_tensorflow +from tensorflow.python import tf2 from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -293,7 +294,11 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec): return [value.indices, value.values, value.dense_shape] def _from_components(self, tensor_list): - return SparseTensor(*tensor_list) + if (all(isinstance(t, np.ndarray) for t in tensor_list) and + not tf2.enabled()): + return SparseTensorValue(*tensor_list) + else: + return SparseTensor(*tensor_list) # The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops # to (un)box the component tensors in a way that allows for batching & diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index b193ebfcedc..cc145b36704 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -187,6 +187,18 @@ class SparseTensorSpecTest(test_util.TensorFlowTestCase, self.assertAllEqual(st.values, st_reconstructed.values) self.assertAllEqual(st.dense_shape, st_reconstructed.dense_shape) + @test_util.run_v1_only("SparseTensorValue is deprecated in v2") + def testFromNumpyComponents(self): + indices = np.array([[0], [8]]) + values = np.array([1.0, 9.0]) + dense_shape = np.array([100]) + spec = sparse_tensor.SparseTensorSpec() + st = spec._from_components([indices, values, dense_shape]) + self.assertIsInstance(st, sparse_tensor.SparseTensorValue) + self.assertAllEqual(st.indices, indices) + self.assertAllEqual(st.values, values) + self.assertAllEqual(st.dense_shape, dense_shape) + @parameterized.parameters([ sparse_tensor.SparseTensorSpec(dtype=dtypes.string), sparse_tensor.SparseTensorSpec(shape=[5, None, None]), diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index d06819cbf90..b9c3193c286 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.python.client import session from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op @@ -1985,16 +1986,17 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): return [value] def _from_components(self, tensor_list): - # Currently, Keras converts tensors to numpy and then calls from_components - # with those np.arrays. So if we see np.ndarrays, convert them to tensors. - # TODO(b/133606651) Update Keras to do something different here. Consider - # adding something like TypeSpec.from_numpy_components? - if isinstance(tensor_list[0], np.ndarray): - tensor_list = [ops.convert_to_tensor(t) for t in tensor_list] - result = tensor_list[0] - for row_splits in reversed(tensor_list[1:]): - result = RaggedTensor(result, row_splits, internal=True) + if (all(isinstance(t, np.ndarray) for t in tensor_list) and + not tf2.enabled()): + for row_splits in reversed(tensor_list[1:]): + result = ragged_tensor_value.RaggedTensorValue(result, row_splits) + else: + if isinstance(tensor_list[0], np.ndarray): + tensor_list = [ops.convert_to_tensor(t) for t in tensor_list] + result = tensor_list[0] + for row_splits in reversed(tensor_list[1:]): + result = RaggedTensor(result, row_splits, internal=True) return result # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py index 453a5208a40..edbd84414da 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor_test.py +++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py @@ -1655,6 +1655,24 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase, rt_reconstructed = rt_spec._from_components(actual_components) self.assertAllEqual(rt, rt_reconstructed) + @test_util.run_v1_only('RaggedTensorValue is deprecated in v2') + def testFromNumpyComponents(self): + spec1 = RaggedTensorSpec(ragged_rank=1, dtype=dtypes.int32) + rt1 = spec1._from_components([np.array([1, 2, 3]), np.array([0, 2, 3])]) + self.assertIsInstance(rt1, ragged_tensor_value.RaggedTensorValue) + self.assertAllEqual(rt1, [[1, 2], [3]]) + + spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32) + rt2 = spec2._from_components([np.array([1, 2, 3]), np.array([0, 2, 3]), + np.array([0, 0, 2, 3])]) + self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue) + self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]]) + + spec3 = RaggedTensorSpec(ragged_rank=0, dtype=dtypes.int32) + rt3 = spec3._from_components([np.array([1, 2, 3])]) + self.assertIsInstance(rt3, np.ndarray) + self.assertAllEqual(rt3, [1, 2, 3]) + @parameterized.parameters([ RaggedTensorSpec(ragged_rank=0, shape=[5, 3]), RaggedTensorSpec(ragged_rank=1),