Adds utility functions for dealing with TensorLists in XLA (no functional change to tensor list kernels).

These will be useful in a follow-up change on lazy initialization of lists passed to XlaWhile.

PiperOrigin-RevId: 236703345
commit b7628a22f0
parent fe8bf74412
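
The change hinges on one representation, spelled out in tensor_list_utils.h below: a TensorList is a two-element XLA tuple of a pre-allocated buffer (dim 0 is the list index) and a scalar push index. The following sketch is illustrative only and not part of the commit; CompileSketch is a hypothetical kernel body showing how the new helpers build and unpack that tuple.

// Illustrative sketch (not in this commit). Assumes the declarations from
// tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h are in scope.
void CompileSketch(XlaOpKernelContext* ctx) {  // hypothetical kernel Compile()
  xla::XlaBuilder* b = ctx->builder();
  // Buffer with room for 8 scalar float elements; dim 0 indexes the list.
  xla::XlaOp buffer = xla::Broadcast(xla::ConstantR0<float>(b, 0.0f), {8});
  xla::XlaOp push_index = xla::ConstantR0<int32>(b, 0);

  // BuildTensorList is just Tuple(buffer, push_index) plus error checking.
  xla::XlaOp list;
  OP_REQUIRES_OK(ctx, BuildTensorList(buffer, push_index, &list));

  // The getters unpack the same tuple elements again.
  xla::XlaOp unpacked_buffer;
  xla::XlaOp unpacked_index;
  OP_REQUIRES_OK(ctx, GetTensorListBuffer(list, &unpacked_buffer));
  OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &unpacked_index));

  ctx->SetTensorListOutput(0, list);
}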
@@ -171,7 +171,7 @@ class ListOpsTest(xla_test.XLATestCase):
           element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
       l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
       # Pushing an element with a different shape should raise an error.
-      with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"):
+      with self.assertRaisesRegexp(errors.InternalError, "shape"):
         l = list_ops.tensor_list_push_back(l, 5.)
         self.evaluate(
             list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
@@ -120,6 +120,7 @@ tf_kernel_library(
         ":case_op",
         ":conv_op_helpers",
         ":if_op",
+        ":tensor_list_utils",
         ":while_op",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
@@ -234,6 +235,23 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tensor_list_utils",
+    srcs = ["tensor_list_utils.cc"],
+    hdrs = ["tensor_list_utils.h"],
+    deps = [
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
 tf_kernel_library(
     name = "while_op",
     srcs = ["while_op.cc"],
@@ -14,13 +14,11 @@ limitations under the License.
 ==============================================================================*/
 
 // XLA TensorList operators.
-// Tensor lists are represented as tuple consisting of a pre-allocated list
-// consisting of the tensors (and where dim 0 is the list index), along with a
-// scalar telling us the current number of elements.
 
 #include <limits>
 #include <vector>
 
+#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -28,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -40,27 +39,16 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
-namespace {
 
-Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op,
-                          TensorShape* tensor_list_shape) {
-  auto shape_or_status = builder->GetShape(op);
-  if (!shape_or_status.ok()) {
-    return shape_or_status.status();
-  }
-  xla::Shape shape = shape_or_status.ValueOrDie();
-  TF_RET_CHECK(shape.IsTuple());
-  return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
-                               tensor_list_shape);
-}
-
+namespace {
+
 class TensorListLengthOp : public XlaOpKernel {
  public:
   explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
 
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::XlaOp tl = ctx->Input(0);
-    xla::XlaOp index = xla::GetTupleElement(tl, 1);
+    xla::XlaOp index;
+    OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &index));
     ctx->SetOutput(0, index);
   }
 
@@ -117,12 +105,15 @@ class TensorListReserveOp : public XlaOpKernel {
     int64 num_elements;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
 
-    xla::XlaOp list;
-    OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list));
+    xla::XlaOp buffer;
+    OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &buffer));
 
-    xla::XlaBuilder* b = ctx->builder();
-    ctx->SetTensorListOutput(
-        0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, num_elements)}));
+    xla::XlaOp output_list;
+    OP_REQUIRES_OK(
+        ctx, BuildTensorList(
+                 buffer, xla::ConstantR0<int32>(ctx->builder(), num_elements),
+                 &output_list));
+    ctx->SetTensorListOutput(0, output_list);
   }
 
  private:
@@ -150,13 +141,15 @@ class EmptyTensorListOp : public XlaOpKernel {
         errors::InvalidArgument("XLA compilation requires a fixed tensor list "
                                 "size. Set the max number of elements."));
 
-    xla::XlaOp list;
+    xla::XlaOp buffer;
     OP_REQUIRES_OK(ctx,
-                   CreateZerosList(ctx, 0, max_num_elements, dtype_, &list));
+                   CreateZerosList(ctx, 0, max_num_elements, dtype_, &buffer));
 
-    xla::XlaBuilder* b = ctx->builder();
-    ctx->SetTensorListOutput(
-        0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, 0)}));
+    xla::XlaOp output_list;
+    OP_REQUIRES_OK(
+        ctx, BuildTensorList(buffer, xla::ConstantR0<int32>(ctx->builder(), 0),
+                             &output_list));
+    ctx->SetTensorListOutput(0, output_list);
   }
 
  private:
@@ -180,7 +173,7 @@ class TensorListElementShapeOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     xla::XlaBuilder* b = ctx->builder();
     TensorShape shape;
-    OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape));
+    OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape));
     shape.RemoveDim(0);
 
     switch (shape_type_) {
@@ -221,9 +214,10 @@ class TensorListGetItemOp : public XlaOpKernel {
     xla::XlaOp state = ctx->Input(0);
 
     TensorShape shape;
-    OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape));
+    OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape));
 
-    xla::XlaOp ta = xla::GetTupleElement(state, 0);
+    xla::XlaOp buffer;
+    OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &buffer));
     xla::XlaOp index = ctx->Input(1);
 
     // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
@@ -233,7 +227,7 @@ class TensorListGetItemOp : public XlaOpKernel {
     auto slice_shape = shape.dim_sizes();
     slice_shape[0] = 1LL;
 
-    xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
+    xla::XlaOp read = xla::DynamicSlice(buffer, start_indices, slice_shape);
     // Remove the leading '1' dimension.
     std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
 
@@ -255,9 +249,9 @@ class TensorListStackOp : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::XlaOp state = ctx->Input(0);
-    xla::XlaOp ta = xla::GetTupleElement(state, 0);
-    ctx->SetOutput(0, ta);
+    xla::XlaOp buffer;
+    OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer));
+    ctx->SetOutput(0, buffer);
   }
 
  private:
@@ -289,8 +283,11 @@ class TensorListFromTensorOp : public XlaOpKernel {
     xla::XlaBuilder* b = ctx->builder();
     const xla::XlaOp tensor = ctx->Input(0);
 
-    ctx->SetTensorListOutput(
-        0, xla::Tuple(b, {tensor, xla::ConstantR0<int32>(b, num_elements)}));
+    xla::XlaOp output_list;
+    OP_REQUIRES_OK(
+        ctx, BuildTensorList(tensor, xla::ConstantR0<int32>(b, num_elements),
+                             &output_list));
+    ctx->SetTensorListOutput(0, output_list);
   }
 
  private:
@@ -306,30 +303,31 @@ REGISTER_XLA_OP(
 // Returns the 0'th element of `tuple` containing the list tensor if it has been
 // initialized already else creates one lazily. This allows lazy initialization
 // of the list on the first call to SetItem or PushBack.
-Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple,
+Status GetInitializedList(const xla::XlaOp& input_list,
                           const TensorShape& element_shape, DataType dtype,
-                          xla::XlaOp* list) {
-  *list = xla::GetTupleElement(tuple, 0);
-  TensorShape list_shape;
-  TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape));
-  int64 leading_dim = list_shape.dim_size(0);
-  TensorShape list_element_shape = list_shape;
-  list_element_shape.RemoveDim(0);
-  // This checks for the lazy initialization contract set by CreateEmptyList.
-  // In TensorListReserve if the element_shape is not known at compile time,
-  // it creates a list with shape [leading_dim, 0].
-  if (element_shape != list_element_shape) {
-    if (list_element_shape.num_elements() != 0) {
-      return errors::InvalidArgument(
-          "Invalid shape of value in TensorListSetItem. Expected: ",
-          list_element_shape.DebugString(),
-          " Actual: ", element_shape.DebugString());
-    }
-    list_shape = element_shape;
-    list_shape.InsertDim(0, leading_dim);
-    *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype),
-                           list_shape.dim_sizes());
+                          xla::XlaOp* output_list_buffer) {
+  bool is_already_initialized;
+  TF_RETURN_IF_ERROR(
+      IsTensorListInitialized(input_list, &is_already_initialized));
+  TensorShape input_list_shape;
+  TF_RETURN_IF_ERROR(GetTensorListBufferShape(input_list, &input_list_shape));
+  TensorShape input_list_element_shape = input_list_shape;
+  input_list_element_shape.RemoveDim(0);
+
+  if (is_already_initialized) {
+    TF_RET_CHECK(element_shape == input_list_element_shape);
+    TF_RETURN_IF_ERROR(GetTensorListBuffer(input_list, output_list_buffer));
+    return Status::OK();
   }
+
+  int64 leading_dim = input_list_shape.dim_size(0);
+  TensorShape output_list_shape = element_shape;
+  output_list_shape.InsertDim(0, leading_dim);
+
+  xla::XlaOp output_list;
+  TF_RETURN_IF_ERROR(
+      InitializeTensorList(input_list, output_list_shape, &output_list));
+  TF_RETURN_IF_ERROR(GetTensorListBuffer(output_list, output_list_buffer));
   return Status::OK();
 }
 
@@ -344,8 +342,10 @@ class TensorListSetItemOp : public XlaOpKernel {
     xla::XlaOp tl = ctx->Input(0);
     TensorShape elem_shape = ctx->InputShape(2);
 
-    xla::XlaOp list;
-    OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list));
+    xla::XlaOp buffer;
+    OP_REQUIRES_OK(ctx, GetInitializedList(tl, elem_shape, dtype_, &buffer));
+    xla::XlaOp push_index;
+    OP_REQUIRES_OK(ctx, GetTensorListPushIndex(tl, &push_index));
 
     xla::XlaOp index = ctx->Input(1);
     xla::XlaOp value = ctx->Input(2);
@@ -359,9 +359,11 @@ class TensorListSetItemOp : public XlaOpKernel {
     slice_shape.InsertDim(0, 1LL);
     auto update = xla::Reshape(value, slice_shape.dim_sizes());
 
-    ctx->SetTensorListOutput(
-        0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices),
-                          xla::GetTupleElement(tl, 1)}));
+    xla::XlaOp output_list;
+    OP_REQUIRES_OK(ctx, BuildTensorList(xla::DynamicUpdateSlice(buffer, update,
+                                                                start_indices),
+                                        push_index, &output_list));
+    ctx->SetTensorListOutput(0, output_list);
   }
 
  private:
@@ -383,11 +385,12 @@ class TensorListPushBackOp : public XlaOpKernel {
     xla::XlaOp list_tuple = ctx->Input(0);
     TensorShape elem_shape = ctx->InputShape(1);
 
-    xla::XlaOp list;
-    OP_REQUIRES_OK(
-        ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list));
+    xla::XlaOp buffer;
+    OP_REQUIRES_OK(ctx,
+                   GetInitializedList(list_tuple, elem_shape, dtype_, &buffer));
 
-    xla::XlaOp index = xla::GetTupleElement(list_tuple, 1);
+    xla::XlaOp index;
+    OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list_tuple, &index));
     xla::XlaOp value = ctx->Input(1);
 
     // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
@@ -399,9 +402,12 @@ class TensorListPushBackOp : public XlaOpKernel {
     slice_shape.InsertDim(0, 1LL);
     auto update = xla::Reshape(value, slice_shape.dim_sizes());
 
-    ctx->SetTensorListOutput(
-        0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices),
-                          index + xla::ConstantR0<int32>(b, 1)}));
+    xla::XlaOp output_list;
+    OP_REQUIRES_OK(
+        ctx,
+        BuildTensorList(xla::DynamicUpdateSlice(buffer, update, start_indices),
+                        index + xla::ConstantR0<int32>(b, 1), &output_list));
+    ctx->SetTensorListOutput(0, output_list);
   }
 
  private:
@@ -423,10 +429,12 @@ class TensorListPopBackOp : public XlaOpKernel {
     xla::XlaOp state = ctx->Input(0);
 
     TensorShape shape;
-    OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape));
+    OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape));
 
-    xla::XlaOp ta = xla::GetTupleElement(state, 0);
-    xla::XlaOp index = xla::GetTupleElement(state, 1);
+    xla::XlaOp ta;
+    OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &ta));
+    xla::XlaOp index;
+    OP_REQUIRES_OK(ctx, GetTensorListPushIndex(state, &index));
 
     index = index - xla::ConstantR0<int32>(b, 1);
 
@@ -441,7 +449,9 @@ class TensorListPopBackOp : public XlaOpKernel {
     // Remove the leading '1' dimension.
     std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
 
-    ctx->SetTensorListOutput(0, xla::Tuple(b, {ta, index}));
+    xla::XlaOp output_list;
+    OP_REQUIRES_OK(ctx, BuildTensorList(ta, index, &output_list));
+    ctx->SetTensorListOutput(0, output_list);
     ctx->SetOutput(1, xla::Reshape(read, value_shape));
   }
 
tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc (new file)
@@ -0,0 +1,100 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+bool IsTensorListInput(XlaOpKernelContext* ctx, int index) {
+  return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList;
+}
+
+Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
+                       xla::XlaOp* output_list) {
+  TF_RET_CHECK(buffer.builder());
+  *output_list = xla::Tuple(buffer.builder(), {buffer, push_index});
+  return Status::OK();
+}
+
+Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) {
+  TF_RET_CHECK(op.builder());
+  *buffer = xla::GetTupleElement(op, 0);
+  return Status::OK();
+}
+
+Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index) {
+  TF_RET_CHECK(op.builder());
+  *push_index = xla::GetTupleElement(op, 1);
+  return Status::OK();
+}
+
+Status GetTensorListBufferShape(const xla::XlaOp& op,
+                                TensorShape* buffer_shape) {
+  TF_RET_CHECK(op.builder());
+  TensorShape shape;
+  TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape,
+                      op.builder()->GetShape(op));
+  return GetTensorListBufferShape(list_tuple_shape, buffer_shape);
+}
+
+Status GetTensorListBufferShape(const xla::Shape& list_shape,
+                                TensorShape* buffer_shape) {
+  TF_RET_CHECK(list_shape.IsTuple());
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
+      xla::ShapeUtil::GetTupleElementShape(list_shape, 0), buffer_shape));
+  return Status::OK();
+}
+
+Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized) {
+  TensorShape list_shape;
+  TF_RETURN_IF_ERROR(GetTensorListBufferShape(op, &list_shape));
+  *is_initialized = !(list_shape.dims() == 2 && list_shape.dim_size(1) == 0);
+  return Status::OK();
+}
+
+Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
+                            const TensorShape& buffer_shape,
+                            xla::XlaOp* output_list) {
+  TensorShape input_buffer_shape;
+  TF_RETURN_IF_ERROR(
+      GetTensorListBufferShape(uninitialized_list, &input_buffer_shape));
+  if (input_buffer_shape.dim_size(0) != buffer_shape.dim_size(0)) {
+    return errors::InvalidArgument(
+        "Number of elements in input list does not match buffer size. ",
+        "input list size: ", input_buffer_shape.dim_size(0),
+        "buffer size: ", buffer_shape.dim_size(0));
+  }
+  xla::XlaBuilder* builder = uninitialized_list.builder();
+  xla::XlaOp input_buffer;
+  TF_RETURN_IF_ERROR(GetTensorListBuffer(uninitialized_list, &input_buffer));
+  TF_ASSIGN_OR_RETURN(const xla::Shape& input_buffer_xla_shape,
+                      builder->GetShape(input_buffer));
+  auto new_buffer = xla::Broadcast(
+      xla::ConstantLiteral(builder, xla::LiteralUtil::Zero(
+                                        input_buffer_xla_shape.element_type())),
+      buffer_shape.dim_sizes());
+  xla::XlaOp push_index;
+  TF_RETURN_IF_ERROR(GetTensorListPushIndex(uninitialized_list, &push_index));
+  return BuildTensorList(new_buffer, push_index, output_list);
+}
+
+}  // namespace tensorflow
tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h (new file)
@@ -0,0 +1,67 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
+
+// TensorList utilities.
+//
+// Tensor lists are represented as tuple consisting of a pre-allocated buffer
+// consisting of the tensors (and where dim 0 is the list index), along with a
+// scalar telling us the next index to push a value at.
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+// Whether the input expression at `index` corresponds to a TensorList.
+bool IsTensorListInput(XlaOpKernelContext* ctx, int index);
+
+// Builds a TensorList from its constituents, `buffer` and `push_index`.
+Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
+                       xla::XlaOp* output_list);
+
+// Returns the buffer for the TensorList.
+Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer);
+
+// Returns the push_index for the TensorList.
+Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index);
+
+// Returns the shape of the TensorList buffer.
+Status GetTensorListBufferShape(const xla::XlaOp& op,
+                                TensorShape* buffer_shape);
+
+// Inputs the TensorList shape and returns the buffer shape.
+Status GetTensorListBufferShape(const xla::Shape& list_shape,
+                                TensorShape* buffer_shape);
+
+// Returns whether the TensorList has been initialized.
+//
+// A TensorList is considered initialized if its element_shape is completely
+// known.
+Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized);
+
+// Inputs an uninitialized list and a buffer_shape and returns an initialized
+// list. The initialized list uses the dtype and push index of the uninitialized
+// list and is filled with zeros.
+Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
+                            const TensorShape& buffer_shape,
+                            xla::XlaOp* output_list);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
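
Taken together, the helpers support the lazy-initialization pattern that GetInitializedList in tensor_list_ops.cc uses and that the follow-up XlaWhile change is meant to build on. Below is a condensed sketch of that flow; it is not part of the commit, and the name GetBufferForWrite is hypothetical, standing in for the real GetInitializedList (which additionally checks element shapes and carries a dtype argument).

// Sketch only: how a kernel can obtain a writable buffer from a possibly
// uninitialized TensorList using the helpers declared above.
Status GetBufferForWrite(const xla::XlaOp& input_list,
                         const TensorShape& element_shape,
                         xla::XlaOp* buffer) {
  bool initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(input_list, &initialized));
  if (initialized) {
    // An initialized list already has a usable buffer.
    return GetTensorListBuffer(input_list, buffer);
  }
  // An uninitialized list has buffer shape [leading_dim, 0]; build a
  // zero-filled buffer of shape [leading_dim] + element_shape and keep the
  // existing push index.
  TensorShape buffer_shape;
  TF_RETURN_IF_ERROR(GetTensorListBufferShape(input_list, &buffer_shape));
  int64 leading_dim = buffer_shape.dim_size(0);
  TensorShape output_shape = element_shape;
  output_shape.InsertDim(0, leading_dim);
  xla::XlaOp initialized_list;
  TF_RETURN_IF_ERROR(
      InitializeTensorList(input_list, output_shape, &initialized_list));
  return GetTensorListBuffer(initialized_list, buffer);
}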