Adds utility functions for dealing with TensorLists in XLA (No functional change to tensor list kernels)
These will be useful in a follow-up change on lazy initialization of lists passed to XlaWhile. PiperOrigin-RevId: 236703345
This commit is contained in:
parent
fe8bf74412
commit
b7628a22f0
@ -171,7 +171,7 @@ class ListOpsTest(xla_test.XLATestCase):
|
||||
element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
|
||||
l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
|
||||
# Pushing an element with a different shape should raise an error.
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"):
|
||||
with self.assertRaisesRegexp(errors.InternalError, "shape"):
|
||||
l = list_ops.tensor_list_push_back(l, 5.)
|
||||
self.evaluate(
|
||||
list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
|
||||
|
@ -120,6 +120,7 @@ tf_kernel_library(
|
||||
":case_op",
|
||||
":conv_op_helpers",
|
||||
":if_op",
|
||||
":tensor_list_utils",
|
||||
":while_op",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
@ -234,6 +235,23 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_list_utils",
|
||||
srcs = ["tensor_list_utils.cc"],
|
||||
hdrs = ["tensor_list_utils.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "while_op",
|
||||
srcs = ["while_op.cc"],
|
||||
|
@ -14,13 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// XLA TensorList operators.
|
||||
// Tensor lists are represented as tuple consisting of a pre-allocated list
|
||||
// consisting of the tensors (and where dim 0 is the list index), along with a
|
||||
// scalar telling us the current number of elements.
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
@ -28,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
@ -40,27 +39,16 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op,
|
||||
TensorShape* tensor_list_shape) {
|
||||
auto shape_or_status = builder->GetShape(op);
|
||||
if (!shape_or_status.ok()) {
|
||||
return shape_or_status.status();
|
||||
}
|
||||
xla::Shape shape = shape_or_status.ValueOrDie();
|
||||
TF_RET_CHECK(shape.IsTuple());
|
||||
return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
|
||||
tensor_list_shape);
|
||||
}
|
||||
namespace {
|
||||
|
||||
class TensorListLengthOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp tl = ctx->Input(0);
|
||||
xla::XlaOp index = xla::GetTupleElement(tl, 1);
|
||||
xla::XlaOp index;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &index));
|
||||
ctx->SetOutput(0, index);
|
||||
}
|
||||
|
||||
@ -117,12 +105,15 @@ class TensorListReserveOp : public XlaOpKernel {
|
||||
int64 num_elements;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
|
||||
|
||||
xla::XlaOp list;
|
||||
OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list));
|
||||
xla::XlaOp buffer;
|
||||
OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &buffer));
|
||||
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
ctx->SetTensorListOutput(
|
||||
0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, num_elements)}));
|
||||
xla::XlaOp output_list;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, BuildTensorList(
|
||||
buffer, xla::ConstantR0<int32>(ctx->builder(), num_elements),
|
||||
&output_list));
|
||||
ctx->SetTensorListOutput(0, output_list);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -150,13 +141,15 @@ class EmptyTensorListOp : public XlaOpKernel {
|
||||
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
|
||||
"size. Set the max number of elements."));
|
||||
|
||||
xla::XlaOp list;
|
||||
xla::XlaOp buffer;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CreateZerosList(ctx, 0, max_num_elements, dtype_, &list));
|
||||
CreateZerosList(ctx, 0, max_num_elements, dtype_, &buffer));
|
||||
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
ctx->SetTensorListOutput(
|
||||
0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, 0)}));
|
||||
xla::XlaOp output_list;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, BuildTensorList(buffer, xla::ConstantR0<int32>(ctx->builder(), 0),
|
||||
&output_list));
|
||||
ctx->SetTensorListOutput(0, output_list);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -180,7 +173,7 @@ class TensorListElementShapeOp : public XlaOpKernel {
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape));
|
||||
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape));
|
||||
shape.RemoveDim(0);
|
||||
|
||||
switch (shape_type_) {
|
||||
@ -221,9 +214,10 @@ class TensorListGetItemOp : public XlaOpKernel {
|
||||
xla::XlaOp state = ctx->Input(0);
|
||||
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape));
|
||||
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape));
|
||||
|
||||
xla::XlaOp ta = xla::GetTupleElement(state, 0);
|
||||
xla::XlaOp buffer;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &buffer));
|
||||
xla::XlaOp index = ctx->Input(1);
|
||||
|
||||
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
|
||||
@ -233,7 +227,7 @@ class TensorListGetItemOp : public XlaOpKernel {
|
||||
auto slice_shape = shape.dim_sizes();
|
||||
slice_shape[0] = 1LL;
|
||||
|
||||
xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
|
||||
xla::XlaOp read = xla::DynamicSlice(buffer, start_indices, slice_shape);
|
||||
// Remove the leading '1' dimension.
|
||||
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
|
||||
|
||||
@ -255,9 +249,9 @@ class TensorListStackOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp state = ctx->Input(0);
|
||||
xla::XlaOp ta = xla::GetTupleElement(state, 0);
|
||||
ctx->SetOutput(0, ta);
|
||||
xla::XlaOp buffer;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer));
|
||||
ctx->SetOutput(0, buffer);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -289,8 +283,11 @@ class TensorListFromTensorOp : public XlaOpKernel {
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
const xla::XlaOp tensor = ctx->Input(0);
|
||||
|
||||
ctx->SetTensorListOutput(
|
||||
0, xla::Tuple(b, {tensor, xla::ConstantR0<int32>(b, num_elements)}));
|
||||
xla::XlaOp output_list;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, BuildTensorList(tensor, xla::ConstantR0<int32>(b, num_elements),
|
||||
&output_list));
|
||||
ctx->SetTensorListOutput(0, output_list);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -306,30 +303,31 @@ REGISTER_XLA_OP(
|
||||
// Returns the 0'th element of `tuple` containing the list tensor if it has been
|
||||
// initialized already else creates one lazily. This allows lazy initialization
|
||||
// of the list on the first call to SetItem or PushBack.
|
||||
Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple,
|
||||
Status GetInitializedList(const xla::XlaOp& input_list,
|
||||
const TensorShape& element_shape, DataType dtype,
|
||||
xla::XlaOp* list) {
|
||||
*list = xla::GetTupleElement(tuple, 0);
|
||||
TensorShape list_shape;
|
||||
TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape));
|
||||
int64 leading_dim = list_shape.dim_size(0);
|
||||
TensorShape list_element_shape = list_shape;
|
||||
list_element_shape.RemoveDim(0);
|
||||
// This checks for the lazy initialization contract set by CreateEmptyList.
|
||||
// In TensorListReserve if the element_shape is not known at compile time,
|
||||
// it creates a list with shape [leading_dim, 0].
|
||||
if (element_shape != list_element_shape) {
|
||||
if (list_element_shape.num_elements() != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid shape of value in TensorListSetItem. Expected: ",
|
||||
list_element_shape.DebugString(),
|
||||
" Actual: ", element_shape.DebugString());
|
||||
}
|
||||
list_shape = element_shape;
|
||||
list_shape.InsertDim(0, leading_dim);
|
||||
*list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype),
|
||||
list_shape.dim_sizes());
|
||||
xla::XlaOp* output_list_buffer) {
|
||||
bool is_already_initialized;
|
||||
TF_RETURN_IF_ERROR(
|
||||
IsTensorListInitialized(input_list, &is_already_initialized));
|
||||
TensorShape input_list_shape;
|
||||
TF_RETURN_IF_ERROR(GetTensorListBufferShape(input_list, &input_list_shape));
|
||||
TensorShape input_list_element_shape = input_list_shape;
|
||||
input_list_element_shape.RemoveDim(0);
|
||||
|
||||
if (is_already_initialized) {
|
||||
TF_RET_CHECK(element_shape == input_list_element_shape);
|
||||
TF_RETURN_IF_ERROR(GetTensorListBuffer(input_list, output_list_buffer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 leading_dim = input_list_shape.dim_size(0);
|
||||
TensorShape output_list_shape = element_shape;
|
||||
output_list_shape.InsertDim(0, leading_dim);
|
||||
|
||||
xla::XlaOp output_list;
|
||||
TF_RETURN_IF_ERROR(
|
||||
InitializeTensorList(input_list, output_list_shape, &output_list));
|
||||
TF_RETURN_IF_ERROR(GetTensorListBuffer(output_list, output_list_buffer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -344,8 +342,10 @@ class TensorListSetItemOp : public XlaOpKernel {
|
||||
xla::XlaOp tl = ctx->Input(0);
|
||||
TensorShape elem_shape = ctx->InputShape(2);
|
||||
|
||||
xla::XlaOp list;
|
||||
OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list));
|
||||
xla::XlaOp buffer;
|
||||
OP_REQUIRES_OK(ctx, GetInitializedList(tl, elem_shape, dtype_, &buffer));
|
||||
xla::XlaOp push_index;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(tl, &push_index));
|
||||
|
||||
xla::XlaOp index = ctx->Input(1);
|
||||
xla::XlaOp value = ctx->Input(2);
|
||||
@ -359,9 +359,11 @@ class TensorListSetItemOp : public XlaOpKernel {
|
||||
slice_shape.InsertDim(0, 1LL);
|
||||
auto update = xla::Reshape(value, slice_shape.dim_sizes());
|
||||
|
||||
ctx->SetTensorListOutput(
|
||||
0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices),
|
||||
xla::GetTupleElement(tl, 1)}));
|
||||
xla::XlaOp output_list;
|
||||
OP_REQUIRES_OK(ctx, BuildTensorList(xla::DynamicUpdateSlice(buffer, update,
|
||||
start_indices),
|
||||
push_index, &output_list));
|
||||
ctx->SetTensorListOutput(0, output_list);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -383,11 +385,12 @@ class TensorListPushBackOp : public XlaOpKernel {
|
||||
xla::XlaOp list_tuple = ctx->Input(0);
|
||||
TensorShape elem_shape = ctx->InputShape(1);
|
||||
|
||||
xla::XlaOp list;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list));
|
||||
xla::XlaOp buffer;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
GetInitializedList(list_tuple, elem_shape, dtype_, &buffer));
|
||||
|
||||
xla::XlaOp index = xla::GetTupleElement(list_tuple, 1);
|
||||
xla::XlaOp index;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list_tuple, &index));
|
||||
xla::XlaOp value = ctx->Input(1);
|
||||
|
||||
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
|
||||
@ -399,9 +402,12 @@ class TensorListPushBackOp : public XlaOpKernel {
|
||||
slice_shape.InsertDim(0, 1LL);
|
||||
auto update = xla::Reshape(value, slice_shape.dim_sizes());
|
||||
|
||||
ctx->SetTensorListOutput(
|
||||
0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices),
|
||||
index + xla::ConstantR0<int32>(b, 1)}));
|
||||
xla::XlaOp output_list;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
BuildTensorList(xla::DynamicUpdateSlice(buffer, update, start_indices),
|
||||
index + xla::ConstantR0<int32>(b, 1), &output_list));
|
||||
ctx->SetTensorListOutput(0, output_list);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -423,10 +429,12 @@ class TensorListPopBackOp : public XlaOpKernel {
|
||||
xla::XlaOp state = ctx->Input(0);
|
||||
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape));
|
||||
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape));
|
||||
|
||||
xla::XlaOp ta = xla::GetTupleElement(state, 0);
|
||||
xla::XlaOp index = xla::GetTupleElement(state, 1);
|
||||
xla::XlaOp ta;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &ta));
|
||||
xla::XlaOp index;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(state, &index));
|
||||
|
||||
index = index - xla::ConstantR0<int32>(b, 1);
|
||||
|
||||
@ -441,7 +449,9 @@ class TensorListPopBackOp : public XlaOpKernel {
|
||||
// Remove the leading '1' dimension.
|
||||
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
|
||||
|
||||
ctx->SetTensorListOutput(0, xla::Tuple(b, {ta, index}));
|
||||
xla::XlaOp output_list;
|
||||
OP_REQUIRES_OK(ctx, BuildTensorList(ta, index, &output_list));
|
||||
ctx->SetTensorListOutput(0, output_list);
|
||||
ctx->SetOutput(1, xla::Reshape(read, value_shape));
|
||||
}
|
||||
|
||||
|
100
tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
Normal file
100
tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
Normal file
@ -0,0 +1,100 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool IsTensorListInput(XlaOpKernelContext* ctx, int index) {
|
||||
return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList;
|
||||
}
|
||||
|
||||
Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
|
||||
xla::XlaOp* output_list) {
|
||||
TF_RET_CHECK(buffer.builder());
|
||||
*output_list = xla::Tuple(buffer.builder(), {buffer, push_index});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) {
|
||||
TF_RET_CHECK(op.builder());
|
||||
*buffer = xla::GetTupleElement(op, 0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index) {
|
||||
TF_RET_CHECK(op.builder());
|
||||
*push_index = xla::GetTupleElement(op, 1);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetTensorListBufferShape(const xla::XlaOp& op,
|
||||
TensorShape* buffer_shape) {
|
||||
TF_RET_CHECK(op.builder());
|
||||
TensorShape shape;
|
||||
TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape,
|
||||
op.builder()->GetShape(op));
|
||||
return GetTensorListBufferShape(list_tuple_shape, buffer_shape);
|
||||
}
|
||||
|
||||
Status GetTensorListBufferShape(const xla::Shape& list_shape,
|
||||
TensorShape* buffer_shape) {
|
||||
TF_RET_CHECK(list_shape.IsTuple());
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
|
||||
xla::ShapeUtil::GetTupleElementShape(list_shape, 0), buffer_shape));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized) {
|
||||
TensorShape list_shape;
|
||||
TF_RETURN_IF_ERROR(GetTensorListBufferShape(op, &list_shape));
|
||||
*is_initialized = !(list_shape.dims() == 2 && list_shape.dim_size(1) == 0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
|
||||
const TensorShape& buffer_shape,
|
||||
xla::XlaOp* output_list) {
|
||||
TensorShape input_buffer_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetTensorListBufferShape(uninitialized_list, &input_buffer_shape));
|
||||
if (input_buffer_shape.dim_size(0) != buffer_shape.dim_size(0)) {
|
||||
return errors::InvalidArgument(
|
||||
"Number of elements in input list does not match buffer size. ",
|
||||
"input list size: ", input_buffer_shape.dim_size(0),
|
||||
"buffer size: ", buffer_shape.dim_size(0));
|
||||
}
|
||||
xla::XlaBuilder* builder = uninitialized_list.builder();
|
||||
xla::XlaOp input_buffer;
|
||||
TF_RETURN_IF_ERROR(GetTensorListBuffer(uninitialized_list, &input_buffer));
|
||||
TF_ASSIGN_OR_RETURN(const xla::Shape& input_buffer_xla_shape,
|
||||
builder->GetShape(input_buffer));
|
||||
auto new_buffer = xla::Broadcast(
|
||||
xla::ConstantLiteral(builder, xla::LiteralUtil::Zero(
|
||||
input_buffer_xla_shape.element_type())),
|
||||
buffer_shape.dim_sizes());
|
||||
xla::XlaOp push_index;
|
||||
TF_RETURN_IF_ERROR(GetTensorListPushIndex(uninitialized_list, &push_index));
|
||||
return BuildTensorList(new_buffer, push_index, output_list);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
67
tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h
Normal file
67
tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
|
||||
|
||||
// TensorList utilities.
|
||||
//
|
||||
// Tensor lists are represented as tuple consisting of a pre-allocated buffer
|
||||
// consisting of the tensors (and where dim 0 is the list index), along with a
|
||||
// scalar telling us the next index to push a value at.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Whether the input expression at `index` corresponds to a TensorList.
|
||||
bool IsTensorListInput(XlaOpKernelContext* ctx, int index);
|
||||
|
||||
// Builds a TensorList from its constituents, `buffer` and `push_index`.
|
||||
Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
|
||||
xla::XlaOp* output_list);
|
||||
|
||||
// Returns the buffer for the TensorList.
|
||||
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer);
|
||||
|
||||
// Returns the push_index for the TensorList.
|
||||
Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index);
|
||||
|
||||
// Returns the shape of the TensorList buffer.
|
||||
Status GetTensorListBufferShape(const xla::XlaOp& op,
|
||||
TensorShape* buffer_shape);
|
||||
|
||||
// Inputs the TensorList shape and returns the buffer shape.
|
||||
Status GetTensorListBufferShape(const xla::Shape& list_shape,
|
||||
TensorShape* buffer_shape);
|
||||
|
||||
// Returns whether the TensorList has been initialized.
|
||||
//
|
||||
// A TensorList is considered initialized if its element_shape is completely
|
||||
// known.
|
||||
Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized);
|
||||
|
||||
// Inputs an uninitialized list and a buffer_shape and returns an initialized
|
||||
// list. The initialized list uses the dtype and push index of the uninitialized
|
||||
// list and is filled with zeros.
|
||||
Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
|
||||
const TensorShape& buffer_shape,
|
||||
xla::XlaOp* output_list);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
|
Loading…
Reference in New Issue
Block a user