For TF version 1.x only: update _from_components methods in RaggedTensorSpec, SparseTensorSpec, and IndexedSlicesSpec to construct value objects (RaggedTensorValue etc.) if the components are numpy arrays.
PiperOrigin-RevId: 260072796
This commit is contained in:
parent
0ea0c474d3
commit
69b368ed82
@ -23,6 +23,7 @@ import collections
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -235,7 +236,14 @@ class IndexedSlicesSpec(type_spec.TypeSpec):
|
||||
return (value.values, value.indices, value.dense_shape)
|
||||
|
||||
def _from_components(self, tensor_list):
|
||||
return IndexedSlices(*tensor_list)
|
||||
if (all(isinstance(t, np.ndarray) for t in tensor_list) and
|
||||
not tf2.enabled()):
|
||||
if len(tensor_list) == 2:
|
||||
return IndexedSlicesValue(tensor_list[0], tensor_list[1], None)
|
||||
else:
|
||||
return IndexedSlicesValue(*tensor_list)
|
||||
else:
|
||||
return IndexedSlices(*tensor_list)
|
||||
|
||||
|
||||
@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
|
||||
|
||||
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import gc
|
||||
import numpy as np
|
||||
import os
|
||||
import threading
|
||||
import weakref
|
||||
@ -303,6 +304,26 @@ class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
|
||||
else:
|
||||
self.assertAllEqual(x.dense_shape, st_reconstructed.dense_shape)
|
||||
|
||||
@test_util.run_v1_only("IndexedSlicesValue is deprecated in v2")
|
||||
def testFromNumpyComponents(self):
|
||||
indices = np.array([3, 8])
|
||||
values = np.array([1.0, 9.0])
|
||||
dense_shape = np.array([100])
|
||||
|
||||
spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32)
|
||||
st1 = spec1._from_components((values, indices, dense_shape))
|
||||
self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue)
|
||||
self.assertAllEqual(st1.indices, indices)
|
||||
self.assertAllEqual(st1.values, values)
|
||||
self.assertAllEqual(st1.dense_shape, dense_shape)
|
||||
|
||||
spec2 = indexed_slices.IndexedSlicesSpec()
|
||||
st2 = spec2._from_components((values, indices))
|
||||
self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue)
|
||||
self.assertAllEqual(st2.indices, indices)
|
||||
self.assertAllEqual(st2.values, values)
|
||||
self.assertIs(st2.dense_shape, None)
|
||||
|
||||
|
||||
class NodeDefConstructorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ import collections
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -293,7 +294,11 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
|
||||
return [value.indices, value.values, value.dense_shape]
|
||||
|
||||
def _from_components(self, tensor_list):
|
||||
return SparseTensor(*tensor_list)
|
||||
if (all(isinstance(t, np.ndarray) for t in tensor_list) and
|
||||
not tf2.enabled()):
|
||||
return SparseTensorValue(*tensor_list)
|
||||
else:
|
||||
return SparseTensor(*tensor_list)
|
||||
|
||||
# The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
|
||||
# to (un)box the component tensors in a way that allows for batching &
|
||||
|
||||
@ -187,6 +187,18 @@ class SparseTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(st.values, st_reconstructed.values)
|
||||
self.assertAllEqual(st.dense_shape, st_reconstructed.dense_shape)
|
||||
|
||||
@test_util.run_v1_only("SparseTensorValue is deprecated in v2")
|
||||
def testFromNumpyComponents(self):
|
||||
indices = np.array([[0], [8]])
|
||||
values = np.array([1.0, 9.0])
|
||||
dense_shape = np.array([100])
|
||||
spec = sparse_tensor.SparseTensorSpec()
|
||||
st = spec._from_components([indices, values, dense_shape])
|
||||
self.assertIsInstance(st, sparse_tensor.SparseTensorValue)
|
||||
self.assertAllEqual(st.indices, indices)
|
||||
self.assertAllEqual(st.values, values)
|
||||
self.assertAllEqual(st.dense_shape, dense_shape)
|
||||
|
||||
@parameterized.parameters([
|
||||
sparse_tensor.SparseTensorSpec(dtype=dtypes.string),
|
||||
sparse_tensor.SparseTensorSpec(shape=[5, None, None]),
|
||||
|
||||
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -1985,16 +1986,17 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
|
||||
return [value]
|
||||
|
||||
def _from_components(self, tensor_list):
|
||||
# Currently, Keras converts tensors to numpy and then calls from_components
|
||||
# with those np.arrays. So if we see np.ndarrays, convert them to tensors.
|
||||
# TODO(b/133606651) Update Keras to do something different here. Consider
|
||||
# adding something like TypeSpec.from_numpy_components?
|
||||
if isinstance(tensor_list[0], np.ndarray):
|
||||
tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
|
||||
|
||||
result = tensor_list[0]
|
||||
for row_splits in reversed(tensor_list[1:]):
|
||||
result = RaggedTensor(result, row_splits, internal=True)
|
||||
if (all(isinstance(t, np.ndarray) for t in tensor_list) and
|
||||
not tf2.enabled()):
|
||||
for row_splits in reversed(tensor_list[1:]):
|
||||
result = ragged_tensor_value.RaggedTensorValue(result, row_splits)
|
||||
else:
|
||||
if isinstance(tensor_list[0], np.ndarray):
|
||||
tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
|
||||
result = tensor_list[0]
|
||||
for row_splits in reversed(tensor_list[1:]):
|
||||
result = RaggedTensor(result, row_splits, internal=True)
|
||||
return result
|
||||
|
||||
# The RaggedTensorSpec tensor_list encoding uses to/from_variant ops
|
||||
|
||||
@ -1655,6 +1655,24 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
rt_reconstructed = rt_spec._from_components(actual_components)
|
||||
self.assertAllEqual(rt, rt_reconstructed)
|
||||
|
||||
@test_util.run_v1_only('RaggedTensorValue is deprecated in v2')
|
||||
def testFromNumpyComponents(self):
|
||||
spec1 = RaggedTensorSpec(ragged_rank=1, dtype=dtypes.int32)
|
||||
rt1 = spec1._from_components([np.array([1, 2, 3]), np.array([0, 2, 3])])
|
||||
self.assertIsInstance(rt1, ragged_tensor_value.RaggedTensorValue)
|
||||
self.assertAllEqual(rt1, [[1, 2], [3]])
|
||||
|
||||
spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32)
|
||||
rt2 = spec2._from_components([np.array([1, 2, 3]), np.array([0, 2, 3]),
|
||||
np.array([0, 0, 2, 3])])
|
||||
self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue)
|
||||
self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]])
|
||||
|
||||
spec3 = RaggedTensorSpec(ragged_rank=0, dtype=dtypes.int32)
|
||||
rt3 = spec3._from_components([np.array([1, 2, 3])])
|
||||
self.assertIsInstance(rt3, np.ndarray)
|
||||
self.assertAllEqual(rt3, [1, 2, 3])
|
||||
|
||||
@parameterized.parameters([
|
||||
RaggedTensorSpec(ragged_rank=0, shape=[5, 3]),
|
||||
RaggedTensorSpec(ragged_rank=1),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user