diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index a380715301b..3c0c36d0c4d 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -171,7 +171,7 @@ class ListOpsTest(xla_test.XLATestCase): element_dtype=dtypes.float32, element_shape=None, max_num_elements=2) l = list_ops.tensor_list_push_back(l, [3.0, 4.0]) # Pushing an element with a different shape should raise an error. - with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"): + with self.assertRaisesRegexp(errors.InternalError, "shape"): l = list_ops.tensor_list_push_back(l, 5.) self.evaluate( list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index c790c4c6723..b4d4b4433eb 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -120,6 +120,7 @@ tf_kernel_library( ":case_op", ":conv_op_helpers", ":if_op", + ":tensor_list_utils", ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -234,6 +235,23 @@ cc_library( ], ) +cc_library( + name = "tensor_list_utils", + srcs = ["tensor_list_utils.cc"], + hdrs = ["tensor_list_utils.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + tf_kernel_library( name = "while_op", srcs = ["while_op.cc"], diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 4feb17d2c86..67a291d7ead 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -14,13 +14,11 @@ limitations under the License. ==============================================================================*/ // XLA TensorList operators. -// Tensor lists are represented as tuple consisting of a pre-allocated list -// consisting of the tensors (and where dim 0 is the list index), along with a -// scalar telling us the current number of elements. #include #include +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -28,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -40,27 +39,16 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { -namespace { -Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, - TensorShape* tensor_list_shape) { - auto shape_or_status = builder->GetShape(op); - if (!shape_or_status.ok()) { - return shape_or_status.status(); - } - xla::Shape shape = shape_or_status.ValueOrDie(); - TF_RET_CHECK(shape.IsTuple()); - return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), - tensor_list_shape); -} +namespace { class TensorListLengthOp : public XlaOpKernel { public: explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp tl = ctx->Input(0); - xla::XlaOp index = xla::GetTupleElement(tl, 1); + xla::XlaOp index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &index)); ctx->SetOutput(0, index); } @@ -117,12 +105,15 @@ class TensorListReserveOp : public XlaOpKernel { int64 num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); - xla::XlaOp list; - OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list)); + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &buffer)); - xla::XlaBuilder* b = ctx->builder(); - ctx->SetTensorListOutput( - 0, xla::Tuple(b, {list, xla::ConstantR0(b, num_elements)})); + xla::XlaOp output_list; + OP_REQUIRES_OK( + ctx, BuildTensorList( + buffer, xla::ConstantR0(ctx->builder(), num_elements), + &output_list)); + ctx->SetTensorListOutput(0, output_list); } private: @@ -150,13 +141,15 @@ class EmptyTensorListOp : public XlaOpKernel { errors::InvalidArgument("XLA compilation requires a fixed tensor list " "size. Set the max number of elements.")); - xla::XlaOp list; + xla::XlaOp buffer; OP_REQUIRES_OK(ctx, - CreateZerosList(ctx, 0, max_num_elements, dtype_, &list)); + CreateZerosList(ctx, 0, max_num_elements, dtype_, &buffer)); - xla::XlaBuilder* b = ctx->builder(); - ctx->SetTensorListOutput( - 0, xla::Tuple(b, {list, xla::ConstantR0(b, 0)})); + xla::XlaOp output_list; + OP_REQUIRES_OK( + ctx, BuildTensorList(buffer, xla::ConstantR0(ctx->builder(), 0), + &output_list)); + ctx->SetTensorListOutput(0, output_list); } private: @@ -180,7 +173,7 @@ class TensorListElementShapeOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); TensorShape shape; - OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape)); + OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape)); shape.RemoveDim(0); switch (shape_type_) { @@ -221,9 +214,10 @@ class TensorListGetItemOp : public XlaOpKernel { xla::XlaOp state = ctx->Input(0); TensorShape shape; - OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape)); - xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &buffer)); xla::XlaOp index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. @@ -233,7 +227,7 @@ class TensorListGetItemOp : public XlaOpKernel { auto slice_shape = shape.dim_sizes(); slice_shape[0] = 1LL; - xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = xla::DynamicSlice(buffer, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); @@ -255,9 +249,9 @@ class TensorListStackOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp state = ctx->Input(0); - xla::XlaOp ta = xla::GetTupleElement(state, 0); - ctx->SetOutput(0, ta); + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer)); + ctx->SetOutput(0, buffer); } private: @@ -289,8 +283,11 @@ class TensorListFromTensorOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const xla::XlaOp tensor = ctx->Input(0); - ctx->SetTensorListOutput( - 0, xla::Tuple(b, {tensor, xla::ConstantR0(b, num_elements)})); + xla::XlaOp output_list; + OP_REQUIRES_OK( + ctx, BuildTensorList(tensor, xla::ConstantR0(b, num_elements), + &output_list)); + ctx->SetTensorListOutput(0, output_list); } private: @@ -306,30 +303,31 @@ REGISTER_XLA_OP( // Returns the 0'th element of `tuple` containing the list tensor if it has been // initialized already else creates one lazily. This allows lazy initialization // of the list on the first call to SetItem or PushBack. -Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple, +Status GetInitializedList(const xla::XlaOp& input_list, const TensorShape& element_shape, DataType dtype, - xla::XlaOp* list) { - *list = xla::GetTupleElement(tuple, 0); - TensorShape list_shape; - TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape)); - int64 leading_dim = list_shape.dim_size(0); - TensorShape list_element_shape = list_shape; - list_element_shape.RemoveDim(0); - // This checks for the lazy initialization contract set by CreateEmptyList. - // In TensorListReserve if the element_shape is not known at compile time, - // it creates a list with shape [leading_dim, 0]. - if (element_shape != list_element_shape) { - if (list_element_shape.num_elements() != 0) { - return errors::InvalidArgument( - "Invalid shape of value in TensorListSetItem. Expected: ", - list_element_shape.DebugString(), - " Actual: ", element_shape.DebugString()); - } - list_shape = element_shape; - list_shape.InsertDim(0, leading_dim); - *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), - list_shape.dim_sizes()); + xla::XlaOp* output_list_buffer) { + bool is_already_initialized; + TF_RETURN_IF_ERROR( + IsTensorListInitialized(input_list, &is_already_initialized)); + TensorShape input_list_shape; + TF_RETURN_IF_ERROR(GetTensorListBufferShape(input_list, &input_list_shape)); + TensorShape input_list_element_shape = input_list_shape; + input_list_element_shape.RemoveDim(0); + + if (is_already_initialized) { + TF_RET_CHECK(element_shape == input_list_element_shape); + TF_RETURN_IF_ERROR(GetTensorListBuffer(input_list, output_list_buffer)); + return Status::OK(); } + + int64 leading_dim = input_list_shape.dim_size(0); + TensorShape output_list_shape = element_shape; + output_list_shape.InsertDim(0, leading_dim); + + xla::XlaOp output_list; + TF_RETURN_IF_ERROR( + InitializeTensorList(input_list, output_list_shape, &output_list)); + TF_RETURN_IF_ERROR(GetTensorListBuffer(output_list, output_list_buffer)); return Status::OK(); } @@ -344,8 +342,10 @@ class TensorListSetItemOp : public XlaOpKernel { xla::XlaOp tl = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(2); - xla::XlaOp list; - OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list)); + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, GetInitializedList(tl, elem_shape, dtype_, &buffer)); + xla::XlaOp push_index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(tl, &push_index)); xla::XlaOp index = ctx->Input(1); xla::XlaOp value = ctx->Input(2); @@ -359,9 +359,11 @@ class TensorListSetItemOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), - xla::GetTupleElement(tl, 1)})); + xla::XlaOp output_list; + OP_REQUIRES_OK(ctx, BuildTensorList(xla::DynamicUpdateSlice(buffer, update, + start_indices), + push_index, &output_list)); + ctx->SetTensorListOutput(0, output_list); } private: @@ -383,11 +385,12 @@ class TensorListPushBackOp : public XlaOpKernel { xla::XlaOp list_tuple = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(1); - xla::XlaOp list; - OP_REQUIRES_OK( - ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list)); + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, + GetInitializedList(list_tuple, elem_shape, dtype_, &buffer)); - xla::XlaOp index = xla::GetTupleElement(list_tuple, 1); + xla::XlaOp index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list_tuple, &index)); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. @@ -399,9 +402,12 @@ class TensorListPushBackOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), - index + xla::ConstantR0(b, 1)})); + xla::XlaOp output_list; + OP_REQUIRES_OK( + ctx, + BuildTensorList(xla::DynamicUpdateSlice(buffer, update, start_indices), + index + xla::ConstantR0(b, 1), &output_list)); + ctx->SetTensorListOutput(0, output_list); } private: @@ -423,10 +429,12 @@ class TensorListPopBackOp : public XlaOpKernel { xla::XlaOp state = ctx->Input(0); TensorShape shape; - OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape)); - xla::XlaOp ta = xla::GetTupleElement(state, 0); - xla::XlaOp index = xla::GetTupleElement(state, 1); + xla::XlaOp ta; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &ta)); + xla::XlaOp index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(state, &index)); index = index - xla::ConstantR0(b, 1); @@ -441,7 +449,9 @@ class TensorListPopBackOp : public XlaOpKernel { // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetTensorListOutput(0, xla::Tuple(b, {ta, index})); + xla::XlaOp output_list; + OP_REQUIRES_OK(ctx, BuildTensorList(ta, index, &output_list)); + ctx->SetTensorListOutput(0, output_list); ctx->SetOutput(1, xla::Reshape(read, value_shape)); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc new file mode 100644 index 00000000000..aa6ee2ac35e --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +bool IsTensorListInput(XlaOpKernelContext* ctx, int index) { + return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList; +} + +Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index, + xla::XlaOp* output_list) { + TF_RET_CHECK(buffer.builder()); + *output_list = xla::Tuple(buffer.builder(), {buffer, push_index}); + return Status::OK(); +} + +Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) { + TF_RET_CHECK(op.builder()); + *buffer = xla::GetTupleElement(op, 0); + return Status::OK(); +} + +Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index) { + TF_RET_CHECK(op.builder()); + *push_index = xla::GetTupleElement(op, 1); + return Status::OK(); +} + +Status GetTensorListBufferShape(const xla::XlaOp& op, + TensorShape* buffer_shape) { + TF_RET_CHECK(op.builder()); + TensorShape shape; + TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape, + op.builder()->GetShape(op)); + return GetTensorListBufferShape(list_tuple_shape, buffer_shape); +} + +Status GetTensorListBufferShape(const xla::Shape& list_shape, + TensorShape* buffer_shape) { + TF_RET_CHECK(list_shape.IsTuple()); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape( + xla::ShapeUtil::GetTupleElementShape(list_shape, 0), buffer_shape)); + return Status::OK(); +} + +Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized) { + TensorShape list_shape; + TF_RETURN_IF_ERROR(GetTensorListBufferShape(op, &list_shape)); + *is_initialized = !(list_shape.dims() == 2 && list_shape.dim_size(1) == 0); + return Status::OK(); +} + +Status InitializeTensorList(const xla::XlaOp& uninitialized_list, + const TensorShape& buffer_shape, + xla::XlaOp* output_list) { + TensorShape input_buffer_shape; + TF_RETURN_IF_ERROR( + GetTensorListBufferShape(uninitialized_list, &input_buffer_shape)); + if (input_buffer_shape.dim_size(0) != buffer_shape.dim_size(0)) { + return errors::InvalidArgument( + "Number of elements in input list does not match buffer size. ", + "input list size: ", input_buffer_shape.dim_size(0), + "buffer size: ", buffer_shape.dim_size(0)); + } + xla::XlaBuilder* builder = uninitialized_list.builder(); + xla::XlaOp input_buffer; + TF_RETURN_IF_ERROR(GetTensorListBuffer(uninitialized_list, &input_buffer)); + TF_ASSIGN_OR_RETURN(const xla::Shape& input_buffer_xla_shape, + builder->GetShape(input_buffer)); + auto new_buffer = xla::Broadcast( + xla::ConstantLiteral(builder, xla::LiteralUtil::Zero( + input_buffer_xla_shape.element_type())), + buffer_shape.dim_sizes()); + xla::XlaOp push_index; + TF_RETURN_IF_ERROR(GetTensorListPushIndex(uninitialized_list, &push_index)); + return BuildTensorList(new_buffer, push_index, output_list); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h new file mode 100644 index 00000000000..937af6f8d77 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ + +// TensorList utilities. +// +// Tensor lists are represented as tuple consisting of a pre-allocated buffer +// consisting of the tensors (and where dim 0 is the list index), along with a +// scalar telling us the next index to push a value at. + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +// Whether the input expression at `index` corresponds to a TensorList. +bool IsTensorListInput(XlaOpKernelContext* ctx, int index); + +// Builds a TensorList from its constituents, `buffer` and `push_index`. +Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index, + xla::XlaOp* output_list); + +// Returns the buffer for the TensorList. +Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer); + +// Returns the push_index for the TensorList. +Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index); + +// Returns the shape of the TensorList buffer. +Status GetTensorListBufferShape(const xla::XlaOp& op, + TensorShape* buffer_shape); + +// Inputs the TensorList shape and returns the buffer shape. +Status GetTensorListBufferShape(const xla::Shape& list_shape, + TensorShape* buffer_shape); + +// Returns whether the TensorList has been initialized. +// +// A TensorList is considered initialized if its element_shape is completely +// known. +Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized); + +// Inputs an uninitialized list and a buffer_shape and returns an initialized +// list. The initialized list uses the dtype and push index of the uninitialized +// list and is filled with zeros. +Status InitializeTensorList(const xla::XlaOp& uninitialized_list, + const TensorShape& buffer_shape, + xla::XlaOp* output_list); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_