Add TensorSpec support for CompositeTensors (such as SparseTensor and RaggedTensor).
PiperOrigin-RevId: 245272144
This commit is contained in:
parent
6b8c6cb57f
commit
ccd35296e8
@ -699,11 +699,13 @@ py_library(
|
||||
":framework_for_generated_wrappers",
|
||||
":function",
|
||||
":graph_util",
|
||||
":indexed_slices_tensor_spec",
|
||||
":lib",
|
||||
":platform",
|
||||
":pywrap_tensorflow",
|
||||
":random_seed",
|
||||
":sparse_tensor",
|
||||
":sparse_tensor_spec",
|
||||
":tensor_spec",
|
||||
":tensor_util",
|
||||
":util",
|
||||
@ -1191,6 +1193,30 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "sparse_tensor_spec",
|
||||
srcs = ["framework/sparse_tensor_spec.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":dtypes",
|
||||
":sparse_tensor",
|
||||
":tensor_shape",
|
||||
":tensor_spec",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "indexed_slices_tensor_spec",
|
||||
srcs = ["framework/indexed_slices_tensor_spec.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":sparse_tensor",
|
||||
":tensor_shape",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tensor_util",
|
||||
srcs = ["framework/tensor_util.py"],
|
||||
|
@ -1039,7 +1039,7 @@ class FunctionSpec(object):
|
||||
|
||||
self._input_signature = tuple(input_signature)
|
||||
self._flat_input_signature = tuple(nest.flatten(input_signature,
|
||||
expand_composites=False))
|
||||
expand_composites=True))
|
||||
|
||||
@property
|
||||
def fullargspec(self):
|
||||
@ -1656,7 +1656,7 @@ def register(func, *args, **kwargs):
|
||||
|
||||
def validate_signature(signature):
|
||||
if any(not isinstance(arg, tensor_spec.TensorSpec)
|
||||
for arg in nest.flatten(signature, expand_composites=False)):
|
||||
for arg in nest.flatten(signature, expand_composites=True)):
|
||||
raise TypeError("Invalid input_signature %s; input_signature must be "
|
||||
"a possibly nested sequence of TensorSpec objects.")
|
||||
|
||||
|
@ -41,6 +41,7 @@ from tensorflow.python.framework import function as tf_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import sparse_tensor_spec
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_ops
|
||||
@ -66,6 +67,7 @@ from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_spec
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import training_ops
|
||||
from tensorflow.python.util import compat
|
||||
@ -964,6 +966,16 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
|
||||
(sparse_tensor.SparseTensor,
|
||||
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
|
||||
(ragged_tensor.RaggedTensor.from_row_lengths,
|
||||
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
|
||||
[ragged_tensor_spec.ragged_tensor_spec([None, None], dtypes.int32)]),
|
||||
(ragged_tensor.RaggedTensor.from_nested_row_lengths,
|
||||
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
|
||||
[ragged_tensor_spec.ragged_tensor_spec([None, None, None],
|
||||
dtypes.int32)]),
|
||||
(sparse_tensor.SparseTensor,
|
||||
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
|
||||
[sparse_tensor_spec.sparse_tensor_spec([None], dtypes.int32)]),
|
||||
]) # pyformat: disable
|
||||
def testCompositeAsArgumentTensorWithDefun(self,
|
||||
factory_fn,
|
||||
|
@ -27,7 +27,6 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import execute
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
@ -79,7 +78,6 @@ def convert_structure_to_signature(structure, arg_names=None):
|
||||
Identical structure that has TensorSpec objects instead of Tensors and
|
||||
UknownArgument instead of any unsupported types.
|
||||
"""
|
||||
structure = composite_tensor.replace_composites_with_components(structure)
|
||||
def encode_arg(arg, path):
|
||||
"""A representation for this argument, for converting into signatures."""
|
||||
if isinstance(arg, ops.Tensor):
|
||||
|
65
tensorflow/python/framework/indexed_slices_tensor_spec.py
Normal file
65
tensorflow/python/framework/indexed_slices_tensor_spec.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TensorSpec factory for sparse tensors."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
|
||||
|
||||
def indexed_slices_tensor_spec(shape=None,
|
||||
dtype=dtypes.float32,
|
||||
num_slices=None,
|
||||
has_dense_shape=True,
|
||||
name=None):
|
||||
"""Returns a tensor specification for a IndexedSlices.
|
||||
|
||||
Returns an object which can be passed to `tf.function` (or other
|
||||
functions that expect `TensorSpec`s) to specify shape constraints
|
||||
for a `IndexedSlices` argument.
|
||||
|
||||
Args:
|
||||
shape: The shape of the IndexedSlices, or `None` to allow any shape.
|
||||
The returned specification object depends only on `shape[1:]`.
|
||||
dtype: Data type of values in the IndexedSlices.
|
||||
num_slices: Number of slices. Default allows for any number of slices.
|
||||
has_dense_shape: Indicates whether the IndexedSlices is expected to have a
|
||||
`dense_shape` component.
|
||||
name: Optional name prefix for the `TensorSpec`s.
|
||||
|
||||
Returns:
|
||||
An object describing the `values`, `indices` and `dense_shape` tensors
|
||||
that comprise the `IndexedSlices`.
|
||||
"""
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
shape = tensor_shape.TensorShape(shape)
|
||||
num_slices = tensor_shape.Shape([num_slices])
|
||||
|
||||
values = tensor_spec.TensorSpec(
|
||||
num_slices.concatenate(shape[1:]), dtype, name)
|
||||
indices = tensor_spec.TensorSpec(num_slices, dtypes.int64,
|
||||
("%s.indices" % name) if name else None)
|
||||
if has_dense_shape:
|
||||
dense_shape = tensor_spec.TensorSpec([shape.ndims], dtypes.int64,
|
||||
("%s.dense_shape" %
|
||||
name) if name else None)
|
||||
else:
|
||||
dense_shape = None
|
||||
return ops.IndexedSlices(values, indices, dense_shape)
|
@ -62,8 +62,14 @@ from tensorflow.python.util import memory
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util import tf_stack
|
||||
from tensorflow.python.util.deprecation import deprecated_args
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# This is to avoid a circular dependency: ops -> tensor_spec -> ops
|
||||
tensor_spec = LazyLoader(
|
||||
"tensor_spec", globals(),
|
||||
"tensorflow.python.framework.tensor_spec")
|
||||
|
||||
# Temporary global switches determining if we should enable the work-in-progress
|
||||
# calls to the C API. These will be removed once all functionality is supported.
|
||||
_USE_C_API = True
|
||||
@ -1686,7 +1692,8 @@ class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
|
||||
|
||||
def __init__(self, values, indices, dense_shape=None):
|
||||
"""Creates an `IndexedSlices`."""
|
||||
_get_graph_from_inputs([values, indices, dense_shape])
|
||||
if not isinstance(values, tensor_spec.TensorSpec):
|
||||
_get_graph_from_inputs([values, indices, dense_shape])
|
||||
self._values = values
|
||||
self._indices = indices
|
||||
self._dense_shape = dense_shape
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -114,21 +115,29 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
|
||||
values: A 1-D tensor of any type and shape `[N]`.
|
||||
dense_shape: A 1-D int64 tensor of shape `[ndims]`.
|
||||
"""
|
||||
with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
|
||||
indices = ops.convert_to_tensor(
|
||||
indices, name="indices", dtype=dtypes.int64)
|
||||
# TODO(touts): Consider adding mutable_values() when 'values'
|
||||
# is a VariableOp and updating users of SparseTensor.
|
||||
values = ops.internal_convert_to_tensor(values, name="values")
|
||||
dense_shape = ops.convert_to_tensor(
|
||||
dense_shape, name="dense_shape", dtype=dtypes.int64)
|
||||
if isinstance(indices, tensor_spec.TensorSpec):
|
||||
if not isinstance(values, tensor_spec.TensorSpec):
|
||||
raise TypeError("Expected values to be a TensorSpec")
|
||||
if not isinstance(dense_shape, tensor_spec.TensorSpec):
|
||||
raise TypeError("Expected dense_shape to be a TensorSpec")
|
||||
if indices.dtype != dtypes.int64 or dense_shape.dtype != dtypes.int64:
|
||||
raise TypeError("indices and dense_shape must have dtype=int64")
|
||||
else:
|
||||
with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
|
||||
indices = ops.convert_to_tensor(
|
||||
indices, name="indices", dtype=dtypes.int64)
|
||||
# TODO(touts): Consider adding mutable_values() when 'values'
|
||||
# is a VariableOp and updating users of SparseTensor.
|
||||
values = ops.internal_convert_to_tensor(values, name="values")
|
||||
dense_shape = ops.convert_to_tensor(
|
||||
dense_shape, name="dense_shape", dtype=dtypes.int64)
|
||||
self._indices = indices
|
||||
self._values = values
|
||||
self._dense_shape = dense_shape
|
||||
|
||||
indices_shape = indices.get_shape().with_rank(2)
|
||||
values_shape = values.get_shape().with_rank(1)
|
||||
dense_shape_shape = dense_shape.get_shape().with_rank(1)
|
||||
indices_shape = indices.shape.with_rank(2)
|
||||
values_shape = values.shape.with_rank(1)
|
||||
dense_shape_shape = dense_shape.shape.with_rank(1)
|
||||
|
||||
# Assert number of rows in indices match the number of elements in values.
|
||||
indices_shape.dims[0].merge_with(values_shape.dims[0])
|
||||
@ -244,8 +253,8 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
|
||||
return [
|
||||
tensor_shape.TensorShape([None, rank]), # indices
|
||||
tensor_shape.TensorShape([None]), # values
|
||||
tensor_shape.TensorShape([rank])
|
||||
] # dense_shape
|
||||
tensor_shape.TensorShape([rank]) # dense_shape
|
||||
]
|
||||
|
||||
@property
|
||||
def _is_graph_tensor(self):
|
||||
|
56
tensorflow/python/framework/sparse_tensor_spec.py
Normal file
56
tensorflow/python/framework/sparse_tensor_spec.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TensorSpec factory for sparse tensors."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
|
||||
|
||||
def sparse_tensor_spec(shape=None,
|
||||
dtype=dtypes.float32,
|
||||
num_values=None,
|
||||
name=None):
|
||||
"""Returns a tensor specification for a SparseTensor.
|
||||
|
||||
Returns an object which can be passed to `tf.function` (or other
|
||||
functions that expect `TensorSpec`s) to specify shape constraints
|
||||
for a `SparseTensor` argument.
|
||||
|
||||
Args:
|
||||
shape: The shape of the SparseTensor, or `None` to allow any shape. The
|
||||
returned specification object depends only on `shape.ndims`.
|
||||
dtype: Data type of values in the SparseTensor.
|
||||
num_values: The number of values in the SparseTensor, or `None` to allow any
|
||||
number of values.
|
||||
name: Optional name prefix for the `TensorSpec`s.
|
||||
|
||||
Returns:
|
||||
An object describing the `values`, `indices` and `dense_shape` tensors
|
||||
that comprise the `SparseTensor`.
|
||||
"""
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
rank = tensor_shape.TensorShape(shape).rank
|
||||
indices = tensor_spec.TensorSpec([num_values, rank], dtypes.int64,
|
||||
("%s.indices" % name) if name else None)
|
||||
values = tensor_spec.TensorSpec([num_values], dtype, name)
|
||||
dense_shape = tensor_spec.TensorSpec(
|
||||
[rank], dtypes.int64, ("%s.dense_shape" % name) if name else None)
|
||||
return sparse_tensor.SparseTensor(indices, values, dense_shape)
|
@ -40,6 +40,7 @@ py_library(
|
||||
":ragged_string_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_shape",
|
||||
":ragged_tensor_spec",
|
||||
":ragged_tensor_value",
|
||||
":ragged_util",
|
||||
":ragged_where_op",
|
||||
@ -434,6 +435,18 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ragged_tensor_spec",
|
||||
srcs = ["ragged_tensor_spec.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
],
|
||||
)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# RaggedTensor Tests
|
||||
#-------------------------------------------------------------------------------
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
@ -241,16 +242,28 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
"of the factory methods instead (e.g., "
|
||||
"RaggedTensor.from_row_lengths())")
|
||||
|
||||
# Validate the arguments.
|
||||
if not isinstance(values, (RaggedTensor, ops.Tensor)):
|
||||
raise TypeError("values must be a Tensor or RaggedTensor.")
|
||||
if not isinstance(row_splits, ops.Tensor):
|
||||
raise TypeError("Row-partitioning argument must be a Tensor.")
|
||||
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise ValueError("Row-partitioning argument must be int32 or int64")
|
||||
values.shape.with_rank_at_least(1)
|
||||
is_tensor_spec = isinstance(row_splits, tensor_spec.TensorSpec)
|
||||
if is_tensor_spec:
|
||||
if not (isinstance(values, tensor_spec.TensorSpec) or
|
||||
(isinstance(values, RaggedTensor) and
|
||||
isinstance(values.row_splits, tensor_spec.TensorSpec))):
|
||||
raise TypeError("Expected values to be a TensorSpec, got %r" % values)
|
||||
else:
|
||||
# Validate the arguments.
|
||||
if not isinstance(row_splits, ops.Tensor):
|
||||
raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
|
||||
row_splits)
|
||||
if not isinstance(values, (RaggedTensor, ops.Tensor)):
|
||||
raise TypeError("values must be a Tensor or RaggedTensor, got %r" %
|
||||
values)
|
||||
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise ValueError("Row-partitioning argument must be int32 or int64")
|
||||
|
||||
# Validate shapes & dtypes.
|
||||
row_splits.shape.assert_has_rank(1)
|
||||
row_splits.set_shape([None])
|
||||
values.shape.with_rank_at_least(1)
|
||||
if not is_tensor_spec:
|
||||
row_splits.set_shape([None])
|
||||
if isinstance(values, RaggedTensor):
|
||||
assert row_splits.dtype == values.row_splits.dtype
|
||||
|
||||
@ -431,6 +444,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
raise TypeError("validate must have type bool")
|
||||
if isinstance(row_splits, (list, tuple)) and not row_splits:
|
||||
raise ValueError("row_splits tensor may not be empty.")
|
||||
if isinstance(row_splits, tensor_spec.TensorSpec):
|
||||
return cls(values=values, row_splits=row_splits, internal=True)
|
||||
|
||||
with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
|
||||
values, row_splits = cls._convert_values_and_row_partition(
|
||||
|
77
tensorflow/python/ops/ragged/ragged_tensor_spec.py
Normal file
77
tensorflow/python/ops/ragged/ragged_tensor_spec.py
Normal file
@ -0,0 +1,77 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TensorSpec factory for ragged tensors."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
|
||||
|
||||
def ragged_tensor_spec(shape=None, dtype=dtypes.float32,
|
||||
ragged_rank=None, row_splits_dtype=dtypes.int64,
|
||||
name=None):
|
||||
"""Returns a tensor specification for a RaggedTensor.
|
||||
|
||||
Returns an object which can be passed to `tf.function` (or other
|
||||
functions that expect `TensorSpec`s) to specify shape constraints
|
||||
for a `RaggedTensor` argument.
|
||||
|
||||
Args:
|
||||
shape: The shape of the RaggedTensor, or `None` to allow any shape.
|
||||
dtype: Data type of values in the RaggedTensor.
|
||||
ragged_rank: Python integer, the ragged rank of the RaggedTensor
|
||||
to be described. Defaults to `shape.ndims - 1`.
|
||||
row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
|
||||
One of `tf.int32` or `tf.int64`.
|
||||
name: Optional name prefix for the `TensorSpec`s.
|
||||
|
||||
Returns:
|
||||
An object describing the `flat_values` and `nested_row_splits` tensors
|
||||
that comprise the `RaggedTensor`.
|
||||
"""
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
shape = tensor_shape.TensorShape(shape)
|
||||
if ragged_rank is None:
|
||||
if shape.ndims is None:
|
||||
raise ValueError("Must specify ragged_rank or a shape with known rank.")
|
||||
ragged_rank = shape.ndims - 1
|
||||
elif not isinstance(ragged_rank, int):
|
||||
raise TypeError("ragged_rank must be an int")
|
||||
if ragged_rank == 0:
|
||||
return tensor_spec.TensorSpec(shape=shape, dtype=dtype, name=name)
|
||||
|
||||
result = tensor_spec.TensorSpec(
|
||||
tensor_shape.TensorShape([None]).concatenate(shape[ragged_rank + 1:]),
|
||||
dtype, name)
|
||||
|
||||
for i in range(ragged_rank - 1, 0, -1):
|
||||
splits = tensor_spec.TensorSpec(
|
||||
[None], row_splits_dtype,
|
||||
"%s.row_splits_%d" % (name, i) if name else None)
|
||||
result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)
|
||||
|
||||
outer_dim = tensor_shape.dimension_at_index(shape, 0)
|
||||
splits_shape = [None if outer_dim is None else outer_dim + 1]
|
||||
splits = tensor_spec.TensorSpec(
|
||||
splits_shape, row_splits_dtype,
|
||||
"%s.row_splits_0" % name if name else None)
|
||||
result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)
|
||||
|
||||
return result
|
Loading…
Reference in New Issue
Block a user