For TF version 1.x only: update _from_components methods in RaggedTensorSpec, SparseTensorSpec, and IndexedSlicesSpec to construct value objects (RaggedTensorValue etc.) if the components are numpy arrays.

PiperOrigin-RevId: 260072796
This commit is contained in:
Edward Loper 2019-07-25 20:12:20 -07:00 committed by TensorFlower Gardener
parent 0ea0c474d3
commit 69b368ed82
6 changed files with 77 additions and 11 deletions

View File

@ -23,6 +23,7 @@ import collections
import warnings
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
@ -235,7 +236,14 @@ class IndexedSlicesSpec(type_spec.TypeSpec):
return (value.values, value.indices, value.dense_shape)
def _from_components(self, tensor_list):
return IndexedSlices(*tensor_list)
if (all(isinstance(t, np.ndarray) for t in tensor_list) and
not tf2.enabled()):
if len(tensor_list) == 2:
return IndexedSlicesValue(tensor_list[0], tensor_list[1], None)
else:
return IndexedSlicesValue(*tensor_list)
else:
return IndexedSlices(*tensor_list)
@tf_export(v1=["convert_to_tensor_or_indexed_slices"])

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import gc
import numpy as np
import os
import threading
import weakref
@ -303,6 +304,26 @@ class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
else:
self.assertAllEqual(x.dense_shape, st_reconstructed.dense_shape)
@test_util.run_v1_only("IndexedSlicesValue is deprecated in v2")
def testFromNumpyComponents(self):
indices = np.array([3, 8])
values = np.array([1.0, 9.0])
dense_shape = np.array([100])
spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32)
st1 = spec1._from_components((values, indices, dense_shape))
self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue)
self.assertAllEqual(st1.indices, indices)
self.assertAllEqual(st1.values, values)
self.assertAllEqual(st1.dense_shape, dense_shape)
spec2 = indexed_slices.IndexedSlicesSpec()
st2 = spec2._from_components((values, indices))
self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue)
self.assertAllEqual(st2.indices, indices)
self.assertAllEqual(st2.values, values)
self.assertIs(st2.dense_shape, None)
class NodeDefConstructorTest(test_util.TensorFlowTestCase):

View File

@ -22,6 +22,7 @@ import collections
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import tf2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -293,7 +294,11 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
return [value.indices, value.values, value.dense_shape]
def _from_components(self, tensor_list):
return SparseTensor(*tensor_list)
if (all(isinstance(t, np.ndarray) for t in tensor_list) and
not tf2.enabled()):
return SparseTensorValue(*tensor_list)
else:
return SparseTensor(*tensor_list)
# The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
# to (un)box the component tensors in a way that allows for batching &

View File

@ -187,6 +187,18 @@ class SparseTensorSpecTest(test_util.TensorFlowTestCase,
self.assertAllEqual(st.values, st_reconstructed.values)
self.assertAllEqual(st.dense_shape, st_reconstructed.dense_shape)
@test_util.run_v1_only("SparseTensorValue is deprecated in v2")
def testFromNumpyComponents(self):
indices = np.array([[0], [8]])
values = np.array([1.0, 9.0])
dense_shape = np.array([100])
spec = sparse_tensor.SparseTensorSpec()
st = spec._from_components([indices, values, dense_shape])
self.assertIsInstance(st, sparse_tensor.SparseTensorValue)
self.assertAllEqual(st.indices, indices)
self.assertAllEqual(st.values, values)
self.assertAllEqual(st.dense_shape, dense_shape)
@parameterized.parameters([
sparse_tensor.SparseTensorSpec(dtype=dtypes.string),
sparse_tensor.SparseTensorSpec(shape=[5, None, None]),

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
@ -1985,16 +1986,17 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
return [value]
def _from_components(self, tensor_list):
# Currently, Keras converts tensors to numpy and then calls from_components
# with those np.arrays. So if we see np.ndarrays, convert them to tensors.
# TODO(b/133606651) Update Keras to do something different here. Consider
# adding something like TypeSpec.from_numpy_components?
if isinstance(tensor_list[0], np.ndarray):
tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
result = tensor_list[0]
for row_splits in reversed(tensor_list[1:]):
result = RaggedTensor(result, row_splits, internal=True)
if (all(isinstance(t, np.ndarray) for t in tensor_list) and
not tf2.enabled()):
for row_splits in reversed(tensor_list[1:]):
result = ragged_tensor_value.RaggedTensorValue(result, row_splits)
else:
if isinstance(tensor_list[0], np.ndarray):
tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
result = tensor_list[0]
for row_splits in reversed(tensor_list[1:]):
result = RaggedTensor(result, row_splits, internal=True)
return result
# The RaggedTensorSpec tensor_list encoding uses to/from_variant ops

View File

@ -1655,6 +1655,24 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
rt_reconstructed = rt_spec._from_components(actual_components)
self.assertAllEqual(rt, rt_reconstructed)
@test_util.run_v1_only('RaggedTensorValue is deprecated in v2')
def testFromNumpyComponents(self):
spec1 = RaggedTensorSpec(ragged_rank=1, dtype=dtypes.int32)
rt1 = spec1._from_components([np.array([1, 2, 3]), np.array([0, 2, 3])])
self.assertIsInstance(rt1, ragged_tensor_value.RaggedTensorValue)
self.assertAllEqual(rt1, [[1, 2], [3]])
spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32)
rt2 = spec2._from_components([np.array([1, 2, 3]), np.array([0, 2, 3]),
np.array([0, 0, 2, 3])])
self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue)
self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]])
spec3 = RaggedTensorSpec(ragged_rank=0, dtype=dtypes.int32)
rt3 = spec3._from_components([np.array([1, 2, 3])])
self.assertIsInstance(rt3, np.ndarray)
self.assertAllEqual(rt3, [1, 2, 3])
@parameterized.parameters([
RaggedTensorSpec(ragged_rank=0, shape=[5, 3]),
RaggedTensorSpec(ragged_rank=1),