Keras was the only user of the fully-internal composite_tensor_utils, but it forked all of the symbols. So, this CL removes it from TF core. Also moves the test into the Keras training_v1_utils test now that all of the methods from composite_tensor_utils have been moved into training_v1_utils.
PiperOrigin-RevId: 342906866 Change-Id: I7f8d01009d37d9b5b68f569b7808e04808021723
This commit is contained in:
parent
da495851f9
commit
40c932c10a
@ -2180,22 +2180,6 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "composite_tensor_utils",
|
|
||||||
srcs = ["framework/composite_tensor_utils.py"],
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
deps = [
|
|
||||||
":array_ops",
|
|
||||||
":composite_tensor",
|
|
||||||
":sparse_ops",
|
|
||||||
":sparse_tensor",
|
|
||||||
"//tensorflow/python/ops/ragged:ragged_concat_ops",
|
|
||||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
|
||||||
"//tensorflow/python/ops/ragged:ragged_tensor_value",
|
|
||||||
"//third_party/py/numpy",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "framework_composite_tensor_test",
|
name = "framework_composite_tensor_test",
|
||||||
srcs = ["framework/composite_tensor_test.py"],
|
srcs = ["framework/composite_tensor_test.py"],
|
||||||
@ -2212,24 +2196,6 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
|
||||||
name = "framework_composite_tensor_utils_test",
|
|
||||||
srcs = ["framework/composite_tensor_utils_test.py"],
|
|
||||||
main = "framework/composite_tensor_utils_test.py",
|
|
||||||
python_version = "PY3",
|
|
||||||
deps = [
|
|
||||||
":array_ops",
|
|
||||||
":composite_tensor",
|
|
||||||
":composite_tensor_utils",
|
|
||||||
":framework_test_lib",
|
|
||||||
":sparse_ops",
|
|
||||||
":sparse_tensor",
|
|
||||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
|
||||||
"//tensorflow/python/ops/ragged:ragged_tensor_value",
|
|
||||||
"//third_party/py/numpy",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "tensor_shape",
|
name = "tensor_shape",
|
||||||
srcs = ["framework/tensor_shape.py"],
|
srcs = ["framework/tensor_shape.py"],
|
||||||
|
@ -1,160 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Helpers for handling composite tensors and composite tensor values."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from tensorflow.python.framework import composite_tensor
|
|
||||||
from tensorflow.python.framework import sparse_tensor
|
|
||||||
from tensorflow.python.ops import sparse_ops
|
|
||||||
from tensorflow.python.ops.ragged import ragged_concat_ops
|
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
|
||||||
|
|
||||||
|
|
||||||
def is_composite_or_composite_value(tensor):
|
|
||||||
"""Returns true if 'tensor' is a CompositeTensor or a CT Value object."""
|
|
||||||
# TODO(b/125094323): This should be isinstance(CompositeTensor) or
|
|
||||||
# isinstance(CompositeTensorValue) once we support that.
|
|
||||||
return isinstance(
|
|
||||||
tensor,
|
|
||||||
(composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue,
|
|
||||||
ragged_tensor_value.RaggedTensorValue))
|
|
||||||
|
|
||||||
|
|
||||||
def get_shape(tensor):
|
|
||||||
"""Returns the shape of the passed composite tensor."""
|
|
||||||
if isinstance(tensor, sparse_tensor.SparseTensorValue):
|
|
||||||
# SparseTensorValues use a 'dense_shape' attribute
|
|
||||||
return tensor.dense_shape
|
|
||||||
else:
|
|
||||||
return tensor.shape
|
|
||||||
|
|
||||||
|
|
||||||
def _append_sparse_tensor_value(target, to_append):
|
|
||||||
"""Append sparse tensor value objects."""
|
|
||||||
# Make sure the sparse tensors are of the same size (except for the 0th dim).
|
|
||||||
if len(target.dense_shape) != len(to_append.dense_shape):
|
|
||||||
raise RuntimeError(
|
|
||||||
'Unable to concatenate %s and %s. The inner dense shapes do not '
|
|
||||||
'have the same number of dimensions (%s vs %s)' %
|
|
||||||
(target, to_append, target.dense_shape, to_append.dense_shape))
|
|
||||||
|
|
||||||
if target.dense_shape[1:] != to_append.dense_shape[1:]:
|
|
||||||
raise RuntimeError(
|
|
||||||
'Unable to concatenate %s and %s. The inner dense shapes do not '
|
|
||||||
'match inner dimensions (%s vs %s)' %
|
|
||||||
(target, to_append, target.dense_shape[1:], to_append.dense_shape[1:]))
|
|
||||||
|
|
||||||
# Add the to_append indices to target, updating the 0th value, and keeping
|
|
||||||
# track of the maximum so we know the final dense_shape of this tensor.
|
|
||||||
base_dim0_value = target.dense_shape[0]
|
|
||||||
max_dim0_value = target.dense_shape[0]
|
|
||||||
new_indices = target.indices
|
|
||||||
for index in to_append.indices:
|
|
||||||
# Here, we iterate through the sparse indices of the tensor to append. For
|
|
||||||
# each index, we update its zeroth value (the batch index) by adding the
|
|
||||||
# number of batch items in the tensor we are appending to (so an index
|
|
||||||
# of [0, 0, 1] for a value that is being appended to a tensor with 0th dim
|
|
||||||
# size 3 would become [3, 0, 1].)
|
|
||||||
index[0] += base_dim0_value
|
|
||||||
max_dim0_value = max(max_dim0_value, index[0])
|
|
||||||
new_indices = np.append(new_indices, [index], axis=0)
|
|
||||||
|
|
||||||
# Extend the values array to contain all of the appended values. These will
|
|
||||||
# be in the same order as the indices added above.
|
|
||||||
new_values = np.concatenate((target.values, to_append.values), axis=0)
|
|
||||||
|
|
||||||
# Create a new dense shape by replacing the value for the 0th dimension
|
|
||||||
# with the new max dim0 value.
|
|
||||||
new_dense_shape = list(target.dense_shape)
|
|
||||||
new_dense_shape[0] = max_dim0_value + 1
|
|
||||||
new_dense_shape = tuple(new_dense_shape)
|
|
||||||
|
|
||||||
return sparse_tensor.SparseTensorValue(
|
|
||||||
indices=new_indices, values=new_values, dense_shape=new_dense_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def _append_ragged_tensor_value(target, to_append):
|
|
||||||
"""Append ragged tensor value objects."""
|
|
||||||
# Make sure the ragged tensors are of the same size (save for the 0th dim).
|
|
||||||
if len(target.shape) != len(to_append.shape):
|
|
||||||
raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))
|
|
||||||
|
|
||||||
if target.shape[1:] != to_append.shape[1:]:
|
|
||||||
raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))
|
|
||||||
|
|
||||||
adjusted_row_splits = to_append.row_splits[1:] + target.row_splits[-1]
|
|
||||||
new_row_splits = np.append(target.row_splits, adjusted_row_splits)
|
|
||||||
if isinstance(target.values, ragged_tensor_value.RaggedTensorValue):
|
|
||||||
new_values = _append_ragged_tensor_value(target.values, to_append.values)
|
|
||||||
else:
|
|
||||||
new_values = np.concatenate((target.values, to_append.values), axis=0)
|
|
||||||
|
|
||||||
return ragged_tensor_value.RaggedTensorValue(new_values, new_row_splits)
|
|
||||||
|
|
||||||
|
|
||||||
def append_composite_tensor(target, to_append):
|
|
||||||
"""Helper function to append composite tensors to each other in the 0 axis.
|
|
||||||
|
|
||||||
In order to support batching within a fit/evaluate/predict call, we need
|
|
||||||
to be able to aggregate within a CompositeTensor. Unfortunately, the CT
|
|
||||||
API currently does not make this easy - especially in V1 mode, where we're
|
|
||||||
working with CompositeTensor Value objects that have no connection with the
|
|
||||||
CompositeTensors that created them.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
target: CompositeTensor or CompositeTensor value object that will be
|
|
||||||
appended to.
|
|
||||||
to_append: CompositeTensor or CompositeTensor value object to append to.
|
|
||||||
'target'.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A CompositeTensor or CompositeTensor value object.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: if concatenation is not possible.
|
|
||||||
"""
|
|
||||||
if type(target) is not type(to_append):
|
|
||||||
raise RuntimeError('Unable to concatenate %s and %s' %
|
|
||||||
(type(target), type(to_append)))
|
|
||||||
|
|
||||||
# Perform type-specific concatenation.
|
|
||||||
# TODO(b/125094323): This should be replaced by a simple call to
|
|
||||||
# target.append() that should work on all of the below classes.
|
|
||||||
|
|
||||||
# If we're seeing a CompositeTensor here, we know it's because we're in
|
|
||||||
# Eager mode (or else we'd have evaluated the CT to a CT Value object
|
|
||||||
# already). Therefore, it's safe to call concat() on it without evaluating
|
|
||||||
# the result any further. If not - that is, if we're seeing a
|
|
||||||
# SparseTensorValue or a RaggedTensorValue - we need to hand-update it
|
|
||||||
# since we're outside of the graph anyways.
|
|
||||||
if isinstance(target, sparse_tensor.SparseTensor):
|
|
||||||
# We need to invoke the sparse version of concatenate here - tf.concat
|
|
||||||
# won't work.
|
|
||||||
return sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0)
|
|
||||||
elif isinstance(target, ragged_tensor.RaggedTensor):
|
|
||||||
return ragged_concat_ops.concat([target, to_append], axis=0)
|
|
||||||
elif isinstance(target, sparse_tensor.SparseTensorValue):
|
|
||||||
return _append_sparse_tensor_value(target, to_append)
|
|
||||||
elif isinstance(target, ragged_tensor_value.RaggedTensorValue):
|
|
||||||
return _append_ragged_tensor_value(target, to_append)
|
|
||||||
else:
|
|
||||||
raise RuntimeError('Attempted to concatenate unsupported object %s.' %
|
|
||||||
type(target))
|
|
@ -1,103 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Tests for tensorflow.python.framework.composite_tensor_utils."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from tensorflow.python.framework import composite_tensor_utils
|
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.framework import sparse_tensor
|
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
|
||||||
from tensorflow.python.platform import googletest
|
|
||||||
|
|
||||||
|
|
||||||
class CompositeTensorTest(test_util.TensorFlowTestCase):
|
|
||||||
|
|
||||||
def test_is_composite(self):
|
|
||||||
# Validate that all composite tensor and value types return true.
|
|
||||||
self.assertTrue(
|
|
||||||
composite_tensor_utils.is_composite_or_composite_value(
|
|
||||||
sparse_tensor.SparseTensor([[0, 0]], [1], [1, 1])))
|
|
||||||
self.assertTrue(
|
|
||||||
composite_tensor_utils.is_composite_or_composite_value(
|
|
||||||
sparse_tensor.SparseTensorValue([[0, 0]], [1], [1, 1])))
|
|
||||||
self.assertTrue(
|
|
||||||
composite_tensor_utils.is_composite_or_composite_value(
|
|
||||||
ragged_tensor.RaggedTensor.from_row_splits(
|
|
||||||
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))))
|
|
||||||
self.assertTrue(
|
|
||||||
composite_tensor_utils.is_composite_or_composite_value(
|
|
||||||
ragged_tensor_value.RaggedTensorValue(
|
|
||||||
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))))
|
|
||||||
|
|
||||||
# Test that numpy arrays and tensors return false.
|
|
||||||
self.assertFalse(
|
|
||||||
composite_tensor_utils.is_composite_or_composite_value(
|
|
||||||
np.ndarray([0, 1])))
|
|
||||||
self.assertFalse(
|
|
||||||
composite_tensor_utils.is_composite_or_composite_value(
|
|
||||||
ops.convert_to_tensor([3, 1])))
|
|
||||||
|
|
||||||
def test_sparse_concatenation(self):
|
|
||||||
tensor_1 = sparse_tensor.SparseTensor([[0, 0]], [1], [1, 1])
|
|
||||||
tensor_2 = sparse_tensor.SparseTensor([[0, 0]], [2], [1, 1])
|
|
||||||
concatenated_tensor = composite_tensor_utils.append_composite_tensor(
|
|
||||||
tensor_1, tensor_2)
|
|
||||||
evaluated_tensor = self.evaluate(concatenated_tensor)
|
|
||||||
self.assertAllEqual(evaluated_tensor.indices, [[0, 0], [1, 0]])
|
|
||||||
self.assertAllEqual(evaluated_tensor.values, [1, 2])
|
|
||||||
self.assertAllEqual(evaluated_tensor.dense_shape, [2, 1])
|
|
||||||
|
|
||||||
def test_sparse_value_concatenation(self):
|
|
||||||
tensor_1 = sparse_tensor.SparseTensorValue([[0, 0]], [1], [1, 1])
|
|
||||||
tensor_2 = sparse_tensor.SparseTensorValue([[0, 0]], [2], [1, 1])
|
|
||||||
concatenated_tensor = composite_tensor_utils.append_composite_tensor(
|
|
||||||
tensor_1, tensor_2)
|
|
||||||
self.assertAllEqual(concatenated_tensor.indices, [[0, 0], [1, 0]])
|
|
||||||
self.assertAllEqual(concatenated_tensor.values, [1, 2])
|
|
||||||
self.assertAllEqual(concatenated_tensor.dense_shape, [2, 1])
|
|
||||||
|
|
||||||
def test_ragged_concatenation(self):
|
|
||||||
tensor_1 = ragged_tensor.RaggedTensor.from_row_splits(
|
|
||||||
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))
|
|
||||||
tensor_2 = ragged_tensor.RaggedTensor.from_row_splits(
|
|
||||||
np.array([3, 4, 5]), np.array([0, 2, 3], dtype=np.int64))
|
|
||||||
concatenated_tensor = composite_tensor_utils.append_composite_tensor(
|
|
||||||
tensor_1, tensor_2)
|
|
||||||
evaluated_tensor = self.evaluate(concatenated_tensor)
|
|
||||||
|
|
||||||
self.assertAllEqual(evaluated_tensor.values, [0, 1, 2, 3, 4, 5])
|
|
||||||
self.assertAllEqual(evaluated_tensor.row_splits, [0, 1, 3, 5, 6])
|
|
||||||
|
|
||||||
def test_ragged_value_concatenation(self):
|
|
||||||
tensor_1 = ragged_tensor_value.RaggedTensorValue(
|
|
||||||
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))
|
|
||||||
tensor_2 = ragged_tensor_value.RaggedTensorValue(
|
|
||||||
np.array([3, 4, 5]), np.array([0, 2, 3], dtype=np.int64))
|
|
||||||
concatenated_tensor = composite_tensor_utils.append_composite_tensor(
|
|
||||||
tensor_1, tensor_2)
|
|
||||||
|
|
||||||
self.assertAllEqual(concatenated_tensor.values, [0, 1, 2, 3, 4, 5])
|
|
||||||
self.assertAllEqual(concatenated_tensor.row_splits, [0, 1, 3, 5, 6])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
googletest.main()
|
|
@ -61,7 +61,6 @@ py_library(
|
|||||||
"//tensorflow/python:check_ops",
|
"//tensorflow/python:check_ops",
|
||||||
"//tensorflow/python:client",
|
"//tensorflow/python:client",
|
||||||
"//tensorflow/python:clip_ops",
|
"//tensorflow/python:clip_ops",
|
||||||
"//tensorflow/python:composite_tensor_utils",
|
|
||||||
"//tensorflow/python:constant_op",
|
"//tensorflow/python:constant_op",
|
||||||
"//tensorflow/python:control_flow_ops",
|
"//tensorflow/python:control_flow_ops",
|
||||||
"//tensorflow/python:ctc_ops",
|
"//tensorflow/python:ctc_ops",
|
||||||
|
@ -47,7 +47,6 @@ py_library(
|
|||||||
":input_spec",
|
":input_spec",
|
||||||
":keras_tensor",
|
":keras_tensor",
|
||||||
":node",
|
":node",
|
||||||
"//tensorflow/python:composite_tensor_utils",
|
|
||||||
"//tensorflow/python:py_checkpoint_reader",
|
"//tensorflow/python:py_checkpoint_reader",
|
||||||
"//tensorflow/python/data",
|
"//tensorflow/python/data",
|
||||||
"//tensorflow/python/distribute:distribute_coordinator",
|
"//tensorflow/python/distribute:distribute_coordinator",
|
||||||
@ -221,7 +220,6 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":base_layer",
|
":base_layer",
|
||||||
"//tensorflow/python:composite_tensor_utils",
|
|
||||||
"//tensorflow/python/data",
|
"//tensorflow/python/data",
|
||||||
"//tensorflow/python/eager:monitoring",
|
"//tensorflow/python/eager:monitoring",
|
||||||
"//tensorflow/python/keras:backend",
|
"//tensorflow/python/keras:backend",
|
||||||
|
@ -30,6 +30,8 @@ from tensorflow.python.data.ops import dataset_ops
|
|||||||
from tensorflow.python.data.ops import readers
|
from tensorflow.python.data.ops import readers
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
@ -37,6 +39,8 @@ from tensorflow.python.keras import testing_utils
|
|||||||
from tensorflow.python.keras.engine import keras_tensor
|
from tensorflow.python.keras.engine import keras_tensor
|
||||||
from tensorflow.python.keras.engine import training_utils_v1
|
from tensorflow.python.keras.engine import training_utils_v1
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
@ -392,5 +396,74 @@ class AggregationTest(keras_parameterized.TestCase):
|
|||||||
self._run_without_steps()
|
self._run_without_steps()
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeTensorTestUtils(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_is_composite(self):
|
||||||
|
# Validate that all composite tensor and value types return true.
|
||||||
|
self.assertTrue(
|
||||||
|
training_utils_v1.is_composite_or_composite_value(
|
||||||
|
sparse_tensor.SparseTensor([[0, 0]], [1], [1, 1])))
|
||||||
|
self.assertTrue(
|
||||||
|
training_utils_v1.is_composite_or_composite_value(
|
||||||
|
sparse_tensor.SparseTensorValue([[0, 0]], [1], [1, 1])))
|
||||||
|
self.assertTrue(
|
||||||
|
training_utils_v1.is_composite_or_composite_value(
|
||||||
|
ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))))
|
||||||
|
self.assertTrue(
|
||||||
|
training_utils_v1.is_composite_or_composite_value(
|
||||||
|
ragged_tensor_value.RaggedTensorValue(
|
||||||
|
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))))
|
||||||
|
|
||||||
|
# Test that numpy arrays and tensors return false.
|
||||||
|
self.assertFalse(
|
||||||
|
training_utils_v1.is_composite_or_composite_value(np.ndarray([0, 1])))
|
||||||
|
self.assertFalse(
|
||||||
|
training_utils_v1.is_composite_or_composite_value(
|
||||||
|
ops.convert_to_tensor_v2_with_dispatch([3, 1])))
|
||||||
|
|
||||||
|
def test_sparse_concatenation(self):
|
||||||
|
tensor_1 = sparse_tensor.SparseTensor([[0, 0]], [1], [1, 1])
|
||||||
|
tensor_2 = sparse_tensor.SparseTensor([[0, 0]], [2], [1, 1])
|
||||||
|
concatenated_tensor = training_utils_v1._append_composite_tensor(
|
||||||
|
tensor_1, tensor_2)
|
||||||
|
evaluated_tensor = self.evaluate(concatenated_tensor)
|
||||||
|
self.assertAllEqual(evaluated_tensor.indices, [[0, 0], [1, 0]])
|
||||||
|
self.assertAllEqual(evaluated_tensor.values, [1, 2])
|
||||||
|
self.assertAllEqual(evaluated_tensor.dense_shape, [2, 1])
|
||||||
|
|
||||||
|
def test_sparse_value_concatenation(self):
|
||||||
|
tensor_1 = sparse_tensor.SparseTensorValue([[0, 0]], [1], [1, 1])
|
||||||
|
tensor_2 = sparse_tensor.SparseTensorValue([[0, 0]], [2], [1, 1])
|
||||||
|
concatenated_tensor = training_utils_v1._append_composite_tensor(
|
||||||
|
tensor_1, tensor_2)
|
||||||
|
self.assertAllEqual(concatenated_tensor.indices, [[0, 0], [1, 0]])
|
||||||
|
self.assertAllEqual(concatenated_tensor.values, [1, 2])
|
||||||
|
self.assertAllEqual(concatenated_tensor.dense_shape, [2, 1])
|
||||||
|
|
||||||
|
def test_ragged_concatenation(self):
|
||||||
|
tensor_1 = ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))
|
||||||
|
tensor_2 = ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
np.array([3, 4, 5]), np.array([0, 2, 3], dtype=np.int64))
|
||||||
|
concatenated_tensor = training_utils_v1._append_composite_tensor(
|
||||||
|
tensor_1, tensor_2)
|
||||||
|
evaluated_tensor = self.evaluate(concatenated_tensor)
|
||||||
|
|
||||||
|
self.assertAllEqual(evaluated_tensor.values, [0, 1, 2, 3, 4, 5])
|
||||||
|
self.assertAllEqual(evaluated_tensor.row_splits, [0, 1, 3, 5, 6])
|
||||||
|
|
||||||
|
def test_ragged_value_concatenation(self):
|
||||||
|
tensor_1 = ragged_tensor_value.RaggedTensorValue(
|
||||||
|
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))
|
||||||
|
tensor_2 = ragged_tensor_value.RaggedTensorValue(
|
||||||
|
np.array([3, 4, 5]), np.array([0, 2, 3], dtype=np.int64))
|
||||||
|
concatenated_tensor = training_utils_v1._append_composite_tensor(
|
||||||
|
tensor_1, tensor_2)
|
||||||
|
|
||||||
|
self.assertAllEqual(concatenated_tensor.values, [0, 1, 2, 3, 4, 5])
|
||||||
|
self.assertAllEqual(concatenated_tensor.row_splits, [0, 1, 3, 5, 6])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user