Adds utility functions for dealing with TensorLists in XLA (No functional change to tensor list kernels)

These will be useful in a follow-up change on lazy initialization of lists passed to XlaWhile.

PiperOrigin-RevId: 236703345
This commit is contained in:
Saurabh Saxena 2019-03-04 12:31:04 -08:00 committed by TensorFlower Gardener
parent fe8bf74412
commit b7628a22f0
5 changed files with 270 additions and 75 deletions

View File

@ -171,7 +171,7 @@ class ListOpsTest(xla_test.XLATestCase):
element_dtype=dtypes.float32, element_shape=None, max_num_elements=2) element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
l = list_ops.tensor_list_push_back(l, [3.0, 4.0]) l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
# Pushing an element with a different shape should raise an error. # Pushing an element with a different shape should raise an error.
with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"): with self.assertRaisesRegexp(errors.InternalError, "shape"):
l = list_ops.tensor_list_push_back(l, 5.) l = list_ops.tensor_list_push_back(l, 5.)
self.evaluate( self.evaluate(
list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)) list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))

View File

@ -120,6 +120,7 @@ tf_kernel_library(
":case_op", ":case_op",
":conv_op_helpers", ":conv_op_helpers",
":if_op", ":if_op",
":tensor_list_utils",
":while_op", ":while_op",
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
@ -234,6 +235,23 @@ cc_library(
], ],
) )
cc_library(
name = "tensor_list_utils",
srcs = ["tensor_list_utils.cc"],
hdrs = ["tensor_list_utils.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_kernel_library( tf_kernel_library(
name = "while_op", name = "while_op",
srcs = ["while_op.cc"], srcs = ["while_op.cc"],

View File

@ -14,13 +14,11 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
// XLA TensorList operators. // XLA TensorList operators.
// Tensor lists are represented as tuple consisting of a pre-allocated list
// consisting of the tensors (and where dim 0 is the list index), along with a
// scalar telling us the current number of elements.
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
@ -28,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
@ -40,27 +39,16 @@ limitations under the License.
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
namespace tensorflow { namespace tensorflow {
namespace {
Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, namespace {
TensorShape* tensor_list_shape) {
auto shape_or_status = builder->GetShape(op);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
xla::Shape shape = shape_or_status.ValueOrDie();
TF_RET_CHECK(shape.IsTuple());
return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
tensor_list_shape);
}
class TensorListLengthOp : public XlaOpKernel { class TensorListLengthOp : public XlaOpKernel {
public: public:
explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp tl = ctx->Input(0); xla::XlaOp index;
xla::XlaOp index = xla::GetTupleElement(tl, 1); OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &index));
ctx->SetOutput(0, index); ctx->SetOutput(0, index);
} }
@ -117,12 +105,15 @@ class TensorListReserveOp : public XlaOpKernel {
int64 num_elements; int64 num_elements;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
xla::XlaOp list; xla::XlaOp buffer;
OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list)); OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &buffer));
xla::XlaBuilder* b = ctx->builder(); xla::XlaOp output_list;
ctx->SetTensorListOutput( OP_REQUIRES_OK(
0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, num_elements)})); ctx, BuildTensorList(
buffer, xla::ConstantR0<int32>(ctx->builder(), num_elements),
&output_list));
ctx->SetTensorListOutput(0, output_list);
} }
private: private:
@ -150,13 +141,15 @@ class EmptyTensorListOp : public XlaOpKernel {
errors::InvalidArgument("XLA compilation requires a fixed tensor list " errors::InvalidArgument("XLA compilation requires a fixed tensor list "
"size. Set the max number of elements.")); "size. Set the max number of elements."));
xla::XlaOp list; xla::XlaOp buffer;
OP_REQUIRES_OK(ctx, OP_REQUIRES_OK(ctx,
CreateZerosList(ctx, 0, max_num_elements, dtype_, &list)); CreateZerosList(ctx, 0, max_num_elements, dtype_, &buffer));
xla::XlaBuilder* b = ctx->builder(); xla::XlaOp output_list;
ctx->SetTensorListOutput( OP_REQUIRES_OK(
0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, 0)})); ctx, BuildTensorList(buffer, xla::ConstantR0<int32>(ctx->builder(), 0),
&output_list));
ctx->SetTensorListOutput(0, output_list);
} }
private: private:
@ -180,7 +173,7 @@ class TensorListElementShapeOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder(); xla::XlaBuilder* b = ctx->builder();
TensorShape shape; TensorShape shape;
OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape)); OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape));
shape.RemoveDim(0); shape.RemoveDim(0);
switch (shape_type_) { switch (shape_type_) {
@ -221,9 +214,10 @@ class TensorListGetItemOp : public XlaOpKernel {
xla::XlaOp state = ctx->Input(0); xla::XlaOp state = ctx->Input(0);
TensorShape shape; TensorShape shape;
OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape));
xla::XlaOp ta = xla::GetTupleElement(state, 0); xla::XlaOp buffer;
OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &buffer));
xla::XlaOp index = ctx->Input(1); xla::XlaOp index = ctx->Input(1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
@ -233,7 +227,7 @@ class TensorListGetItemOp : public XlaOpKernel {
auto slice_shape = shape.dim_sizes(); auto slice_shape = shape.dim_sizes();
slice_shape[0] = 1LL; slice_shape[0] = 1LL;
xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); xla::XlaOp read = xla::DynamicSlice(buffer, start_indices, slice_shape);
// Remove the leading '1' dimension. // Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end()); std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
@ -255,9 +249,9 @@ class TensorListStackOp : public XlaOpKernel {
} }
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp state = ctx->Input(0); xla::XlaOp buffer;
xla::XlaOp ta = xla::GetTupleElement(state, 0); OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer));
ctx->SetOutput(0, ta); ctx->SetOutput(0, buffer);
} }
private: private:
@ -289,8 +283,11 @@ class TensorListFromTensorOp : public XlaOpKernel {
xla::XlaBuilder* b = ctx->builder(); xla::XlaBuilder* b = ctx->builder();
const xla::XlaOp tensor = ctx->Input(0); const xla::XlaOp tensor = ctx->Input(0);
ctx->SetTensorListOutput( xla::XlaOp output_list;
0, xla::Tuple(b, {tensor, xla::ConstantR0<int32>(b, num_elements)})); OP_REQUIRES_OK(
ctx, BuildTensorList(tensor, xla::ConstantR0<int32>(b, num_elements),
&output_list));
ctx->SetTensorListOutput(0, output_list);
} }
private: private:
@ -306,30 +303,31 @@ REGISTER_XLA_OP(
// Returns the 0'th element of `tuple` containing the list tensor if it has been // Returns the 0'th element of `tuple` containing the list tensor if it has been
// initialized already else creates one lazily. This allows lazy initialization // initialized already else creates one lazily. This allows lazy initialization
// of the list on the first call to SetItem or PushBack. // of the list on the first call to SetItem or PushBack.
Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple, Status GetInitializedList(const xla::XlaOp& input_list,
const TensorShape& element_shape, DataType dtype, const TensorShape& element_shape, DataType dtype,
xla::XlaOp* list) { xla::XlaOp* output_list_buffer) {
*list = xla::GetTupleElement(tuple, 0); bool is_already_initialized;
TensorShape list_shape; TF_RETURN_IF_ERROR(
TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape)); IsTensorListInitialized(input_list, &is_already_initialized));
int64 leading_dim = list_shape.dim_size(0); TensorShape input_list_shape;
TensorShape list_element_shape = list_shape; TF_RETURN_IF_ERROR(GetTensorListBufferShape(input_list, &input_list_shape));
list_element_shape.RemoveDim(0); TensorShape input_list_element_shape = input_list_shape;
// This checks for the lazy initialization contract set by CreateEmptyList. input_list_element_shape.RemoveDim(0);
// In TensorListReserve if the element_shape is not known at compile time,
// it creates a list with shape [leading_dim, 0]. if (is_already_initialized) {
if (element_shape != list_element_shape) { TF_RET_CHECK(element_shape == input_list_element_shape);
if (list_element_shape.num_elements() != 0) { TF_RETURN_IF_ERROR(GetTensorListBuffer(input_list, output_list_buffer));
return errors::InvalidArgument( return Status::OK();
"Invalid shape of value in TensorListSetItem. Expected: ",
list_element_shape.DebugString(),
" Actual: ", element_shape.DebugString());
}
list_shape = element_shape;
list_shape.InsertDim(0, leading_dim);
*list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype),
list_shape.dim_sizes());
} }
int64 leading_dim = input_list_shape.dim_size(0);
TensorShape output_list_shape = element_shape;
output_list_shape.InsertDim(0, leading_dim);
xla::XlaOp output_list;
TF_RETURN_IF_ERROR(
InitializeTensorList(input_list, output_list_shape, &output_list));
TF_RETURN_IF_ERROR(GetTensorListBuffer(output_list, output_list_buffer));
return Status::OK(); return Status::OK();
} }
@ -344,8 +342,10 @@ class TensorListSetItemOp : public XlaOpKernel {
xla::XlaOp tl = ctx->Input(0); xla::XlaOp tl = ctx->Input(0);
TensorShape elem_shape = ctx->InputShape(2); TensorShape elem_shape = ctx->InputShape(2);
xla::XlaOp list; xla::XlaOp buffer;
OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list)); OP_REQUIRES_OK(ctx, GetInitializedList(tl, elem_shape, dtype_, &buffer));
xla::XlaOp push_index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(tl, &push_index));
xla::XlaOp index = ctx->Input(1); xla::XlaOp index = ctx->Input(1);
xla::XlaOp value = ctx->Input(2); xla::XlaOp value = ctx->Input(2);
@ -359,9 +359,11 @@ class TensorListSetItemOp : public XlaOpKernel {
slice_shape.InsertDim(0, 1LL); slice_shape.InsertDim(0, 1LL);
auto update = xla::Reshape(value, slice_shape.dim_sizes()); auto update = xla::Reshape(value, slice_shape.dim_sizes());
ctx->SetTensorListOutput( xla::XlaOp output_list;
0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), OP_REQUIRES_OK(ctx, BuildTensorList(xla::DynamicUpdateSlice(buffer, update,
xla::GetTupleElement(tl, 1)})); start_indices),
push_index, &output_list));
ctx->SetTensorListOutput(0, output_list);
} }
private: private:
@ -383,11 +385,12 @@ class TensorListPushBackOp : public XlaOpKernel {
xla::XlaOp list_tuple = ctx->Input(0); xla::XlaOp list_tuple = ctx->Input(0);
TensorShape elem_shape = ctx->InputShape(1); TensorShape elem_shape = ctx->InputShape(1);
xla::XlaOp list; xla::XlaOp buffer;
OP_REQUIRES_OK( OP_REQUIRES_OK(ctx,
ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list)); GetInitializedList(list_tuple, elem_shape, dtype_, &buffer));
xla::XlaOp index = xla::GetTupleElement(list_tuple, 1); xla::XlaOp index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list_tuple, &index));
xla::XlaOp value = ctx->Input(1); xla::XlaOp value = ctx->Input(1);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
@ -399,9 +402,12 @@ class TensorListPushBackOp : public XlaOpKernel {
slice_shape.InsertDim(0, 1LL); slice_shape.InsertDim(0, 1LL);
auto update = xla::Reshape(value, slice_shape.dim_sizes()); auto update = xla::Reshape(value, slice_shape.dim_sizes());
ctx->SetTensorListOutput( xla::XlaOp output_list;
0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), OP_REQUIRES_OK(
index + xla::ConstantR0<int32>(b, 1)})); ctx,
BuildTensorList(xla::DynamicUpdateSlice(buffer, update, start_indices),
index + xla::ConstantR0<int32>(b, 1), &output_list));
ctx->SetTensorListOutput(0, output_list);
} }
private: private:
@ -423,10 +429,12 @@ class TensorListPopBackOp : public XlaOpKernel {
xla::XlaOp state = ctx->Input(0); xla::XlaOp state = ctx->Input(0);
TensorShape shape; TensorShape shape;
OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape));
xla::XlaOp ta = xla::GetTupleElement(state, 0); xla::XlaOp ta;
xla::XlaOp index = xla::GetTupleElement(state, 1); OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &ta));
xla::XlaOp index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(state, &index));
index = index - xla::ConstantR0<int32>(b, 1); index = index - xla::ConstantR0<int32>(b, 1);
@ -441,7 +449,9 @@ class TensorListPopBackOp : public XlaOpKernel {
// Remove the leading '1' dimension. // Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end()); std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
ctx->SetTensorListOutput(0, xla::Tuple(b, {ta, index})); xla::XlaOp output_list;
OP_REQUIRES_OK(ctx, BuildTensorList(ta, index, &output_list));
ctx->SetTensorListOutput(0, output_list);
ctx->SetOutput(1, xla::Reshape(read, value_shape)); ctx->SetOutput(1, xla::Reshape(read, value_shape));
} }

View File

@ -0,0 +1,100 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
bool IsTensorListInput(XlaOpKernelContext* ctx, int index) {
return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList;
}
Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
xla::XlaOp* output_list) {
TF_RET_CHECK(buffer.builder());
*output_list = xla::Tuple(buffer.builder(), {buffer, push_index});
return Status::OK();
}
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) {
TF_RET_CHECK(op.builder());
*buffer = xla::GetTupleElement(op, 0);
return Status::OK();
}
Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index) {
TF_RET_CHECK(op.builder());
*push_index = xla::GetTupleElement(op, 1);
return Status::OK();
}
Status GetTensorListBufferShape(const xla::XlaOp& op,
TensorShape* buffer_shape) {
TF_RET_CHECK(op.builder());
TensorShape shape;
TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape,
op.builder()->GetShape(op));
return GetTensorListBufferShape(list_tuple_shape, buffer_shape);
}
Status GetTensorListBufferShape(const xla::Shape& list_shape,
TensorShape* buffer_shape) {
TF_RET_CHECK(list_shape.IsTuple());
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
xla::ShapeUtil::GetTupleElementShape(list_shape, 0), buffer_shape));
return Status::OK();
}
Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized) {
TensorShape list_shape;
TF_RETURN_IF_ERROR(GetTensorListBufferShape(op, &list_shape));
*is_initialized = !(list_shape.dims() == 2 && list_shape.dim_size(1) == 0);
return Status::OK();
}
Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
const TensorShape& buffer_shape,
xla::XlaOp* output_list) {
TensorShape input_buffer_shape;
TF_RETURN_IF_ERROR(
GetTensorListBufferShape(uninitialized_list, &input_buffer_shape));
if (input_buffer_shape.dim_size(0) != buffer_shape.dim_size(0)) {
return errors::InvalidArgument(
"Number of elements in input list does not match buffer size. ",
"input list size: ", input_buffer_shape.dim_size(0),
"buffer size: ", buffer_shape.dim_size(0));
}
xla::XlaBuilder* builder = uninitialized_list.builder();
xla::XlaOp input_buffer;
TF_RETURN_IF_ERROR(GetTensorListBuffer(uninitialized_list, &input_buffer));
TF_ASSIGN_OR_RETURN(const xla::Shape& input_buffer_xla_shape,
builder->GetShape(input_buffer));
auto new_buffer = xla::Broadcast(
xla::ConstantLiteral(builder, xla::LiteralUtil::Zero(
input_buffer_xla_shape.element_type())),
buffer_shape.dim_sizes());
xla::XlaOp push_index;
TF_RETURN_IF_ERROR(GetTensorListPushIndex(uninitialized_list, &push_index));
return BuildTensorList(new_buffer, push_index, output_list);
}
} // namespace tensorflow

View File

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
// TensorList utilities.
//
// Tensor lists are represented as tuple consisting of a pre-allocated buffer
// consisting of the tensors (and where dim 0 is the list index), along with a
// scalar telling us the next index to push a value at.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
// Whether the input expression at `index` corresponds to a TensorList.
bool IsTensorListInput(XlaOpKernelContext* ctx, int index);
// Builds a TensorList from its constituents, `buffer` and `push_index`.
Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
xla::XlaOp* output_list);
// Returns the buffer for the TensorList.
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer);
// Returns the push_index for the TensorList.
Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index);
// Returns the shape of the TensorList buffer.
Status GetTensorListBufferShape(const xla::XlaOp& op,
TensorShape* buffer_shape);
// Inputs the TensorList shape and returns the buffer shape.
Status GetTensorListBufferShape(const xla::Shape& list_shape,
TensorShape* buffer_shape);
// Returns whether the TensorList has been initialized.
//
// A TensorList is considered initialized if its element_shape is completely
// known.
Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized);
// Inputs an uninitialized list and a buffer_shape and returns an initialized
// list. The initialized list uses the dtype and push index of the uninitialized
// list and is filled with zeros.
Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
const TensorShape& buffer_shape,
xla::XlaOp* output_list);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_