From ce2b635fcd19774bf8b20188051f8215f580588a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 3 Sep 2019 09:47:58 -0700
Subject: [PATCH] [TF:XLA] Support batch_dims in ResourceGatherOp.

PiperOrigin-RevId: 266950621
---
 .../compiler/tf2xla/kernels/gather_op.cc      | 152 ++++++++++--------
 .../tf2xla/kernels/gather_op_helpers.h        |   7 +
 .../compiler/tf2xla/kernels/variable_ops.cc   |  25 ++-
 tensorflow/python/kernel_tests/BUILD          |   1 -
 .../resource_variable_ops_test.py             |   4 +-
 5 files changed, 106 insertions(+), 83 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 489ffd3fdad..84a0e78ff6e 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -25,8 +25,10 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
 
@@ -150,6 +152,85 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
   return Status::OK();
 }
 
+Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
+                                    const xla::XlaOp input,
+                                    const TensorShape& input_shape,
+                                    int batch_dims, xla::XlaOp* gather_output) {
+  auto indices = context->Input(1);
+  auto indices_shape = context->InputShape(1);
+
+  absl::optional<int64> axis;
+  if (context->num_inputs() == 3) {
+    const TensorShape axis_shape = context->InputShape(2);
+    if (!TensorShapeUtils::IsScalar(axis_shape)) {
+      return errors::InvalidArgument("axis must be scalar");
+    }
+    DataType axis_type = context->input_type(2);
+    if (axis_type != DT_INT32 && axis_type != DT_INT64) {
+      return errors::InvalidArgument("axis must be int32 or int64");
+    }
+
+    int64 axis_input;
+    TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input));
+
+    const auto params_dims = input_shape.dims();
+    if (-params_dims > axis_input || axis_input >= params_dims) {
+      return errors::InvalidArgument("Expected axis in the range [",
+                                     -params_dims, ", ", params_dims,
+                                     "), but got ", axis_input);
+    }
+    if (axis_input < 0) {
+      axis_input += params_dims;
+    }
+    axis = axis_input;
+  }
+
+  if (batch_dims != 0) {
+    if (batch_dims < 0) {
+      batch_dims = indices_shape.dims() + batch_dims;
+    }
+
+    axis = axis.value_or(batch_dims);
+
+    if (batch_dims < -indices_shape.dims() ||
+        batch_dims >= indices_shape.dims()) {
+      return errors::InvalidArgument(
+          "Expected batch_dims in the range [", -indices_shape.dims(), ", ",
+          indices_shape.dims(), "), but got ", batch_dims);
+    }
+
+    if (batch_dims >= input_shape.dims()) {
+      return errors::InvalidArgument("batch_dims (", batch_dims,
+                                     ") must be less than rank(input) (",
+                                     input_shape.dims(), ").");
+    }
+
+    if (*axis < batch_dims) {
+      return errors::InvalidArgument("batch_dims (", batch_dims,
+                                     ") must be less than or equal to ",
+                                     "axis (", *axis, ").");
+    }
+  }
+
+  axis = axis.value_or(0);
+  DataType index_type = context->input_type(1);
+  if (index_type != DT_INT32 && index_type != DT_INT64) {
+    return errors::InvalidArgument("indices must be int32 or int64");
+  }
+
+  xla::XlaOp gather;
+  if (batch_dims > 0) {
+    *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims);
+  } else {
+    // XlaGather() manages degenerate cases, like empty-indices, which are
+    // error conditions and caught above if batch_dims is not 0.
+    TF_RETURN_IF_ERROR(
+        XlaGather(input, input_shape, indices, indices_shape, *axis,
+                  /*indices_are_nd=*/false, context->input_type(0), index_type,
+                  context->builder(), gather_output));
+  }
+  return Status::OK();
+}
 class GatherOp : public XlaOpKernel {
  public:
   explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
@@ -164,76 +245,11 @@ class GatherOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* context) override {
     auto input = context->Input(0);
     auto input_shape = context->InputShape(0);
-    auto indices = context->Input(1);
-    auto indices_shape = context->InputShape(1);
-
-    absl::optional<int64> axis;
-    if (context->num_inputs() == 3) {
-      const TensorShape axis_shape = context->InputShape(2);
-      OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
-                  errors::InvalidArgument("axis must be scalar"));
-      DataType axis_type = input_type(2);
-      OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
-                  errors::InvalidArgument("axis must be int32 or int64"));
-
-      int64 axis_input;
-      OP_REQUIRES_OK(context,
-                     context->ConstantInputAsIntScalar(2, &axis_input));
-
-      const auto params_dims = input_shape.dims();
-      OP_REQUIRES(context,
-                  -params_dims <= axis_input && axis_input < params_dims,
-                  errors::InvalidArgument("Expected axis in the range [",
-                                          -params_dims, ", ", params_dims,
-                                          "), but got ", axis_input));
-      if (axis_input < 0) {
-        axis_input += params_dims;
-      }
-      axis = axis_input;
-    }
-
-    if (batch_dims_ != 0) {
-      if (batch_dims_ < 0) {
-        batch_dims_ = indices_shape.dims() + batch_dims_;
-      }
-
-      axis = axis.value_or(batch_dims_);
-
-      OP_REQUIRES(context,
-                  batch_dims_ >= -indices_shape.dims() &&
-                      batch_dims_ < indices_shape.dims(),
-                  errors::InvalidArgument("Expected batch_dims in the range [",
-                                          -indices_shape.dims(), ", ",
-                                          indices_shape.dims(), "), but got ",
-                                          batch_dims_));
-
-      OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
-                  errors::InvalidArgument("batch_dims (", batch_dims_,
-                                          ") must be less than rank(input) (",
-                                          input_shape.dims(), ")."));
-
-      OP_REQUIRES(context, *axis >= batch_dims_,
-                  errors::InvalidArgument("batch_dims (", batch_dims_,
-                                          ") must be less than or equal to ",
-                                          "axis (", *axis, ")."));
-    }
-
-    axis = axis.value_or(0);
-    DataType index_type = input_type(1);
-    OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
-                errors::InvalidArgument("indices must be int32 or int64"));
 
     xla::XlaOp gather;
-    if (batch_dims_ > 0) {
-      gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
-    } else {
-      // XlaGather() manages degenerate cases, like empty-indices, which are
-      // error conditions and caught above if batch_dims is not 0.
-      OP_REQUIRES_OK(
-          context, XlaGather(input, input_shape, indices, indices_shape, *axis,
-                             /*indices_are_nd=*/false, input_type(0),
-                             index_type, context->builder(), &gather));
-    }
+    OP_REQUIRES_OK(context,
+                   XlaGatherWithBatchDimsOpImpl(context, input, input_shape,
+                                                batch_dims_, &gather));
 
     context->SetOutput(0, gather);
   }
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
index 92346283c31..7bd25230d46 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
@@ -39,6 +39,13 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
                  DataType index_type, xla::XlaBuilder* builder,
                  xla::XlaOp* gather_output);
 
+// The implementation of Gather and ResourceGather through XLA. Uses `input` as
+// the input instead of context->input(0) in order to allow ResourceGather to
+// handle obtaining the data from the ResourceVariable.
+Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
+                                    const xla::XlaOp input,
+                                    const TensorShape& input_shape,
+                                    int batch_dims, xla::XlaOp* gather_output);
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index 7b4125ab76e..60424f85840 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
@@ -122,27 +123,25 @@ REGISTER_XLA_OP(
 
 class ResourceGatherOp : public XlaOpKernel {
  public:
-  explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_dims", &batch_dims_));
+  }
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::XlaBuilder* builder = ctx->builder();
     DataType type = ctx->expected_output_dtype(0);
-    TensorShape resource_shape;
-    xla::XlaOp resource_handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
-                                               &resource_handle));
+    TensorShape input_shape;
+    xla::XlaOp input;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &input_shape, &input));
 
-    auto indices = ctx->Input(1);
-    auto indices_shape = ctx->InputShape(1);
-    DataType index_type = ctx->input_type(1);
     xla::XlaOp gather;
-    OP_REQUIRES_OK(
-        ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
-                       /*axis=*/0, /*indices_are_nd=*/false, type, index_type,
-                       builder, &gather));
+    OP_REQUIRES_OK(ctx, XlaGatherWithBatchDimsOpImpl(ctx, input, input_shape,
+                                                     batch_dims_, &gather));
     ctx->SetOutput(0, gather);
   }
+
+ private:
+  int32 batch_dims_;
 };
 
 REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp);
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index e5b741b8077..18682e03e59 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -861,7 +861,6 @@ cuda_py_test(
     ],
     # TODO(b/128347673): Re-enable.
     tags = ["no_windows"],
-    xla_enable_strict_auto_jit = True,
 )
 
 tf_py_test(
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 70c6c7ecfbc..14a4c531ccc 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -986,7 +986,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
     x = resource_variable_ops.var_handle_op(
         dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
         container=ops.get_default_graph()._container)
-    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
+    with self.assertRaisesOpError(
+        "(Resource .*/var5/.* does not exist|Read of uninitialized variable)"
+    ):
       resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
 
   @test_util.run_deprecated_v1
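
For illustration, a minimal Python sketch of the batch_dims gather semantics this patch makes available when compiling through the tf2xla kernels. It is not part of the patch: the variable and index values are invented, the flag is spelled `experimental_compile` on TF releases that predate `jit_compile`, and whether tracing emits ResourceGather directly or a variable read followed by GatherV2 depends on the TF version; either way the gather-with-batch_dims lowering above is exercised.

import tensorflow as tf

v = tf.Variable([[10., 11., 12.],
                 [20., 21., 22.]])  # params, shape [2, 3]
idx = tf.constant([[2, 0],
                   [0, 1]])         # one index row per batch row, shape [2, 2]

@tf.function(jit_compile=True)  # force compilation through the XLA kernels
def batched_gather():
  # batch_dims=1 means output[b, i] = v[b, idx[b, i]]. For batch_dims > 0 the
  # kernel above lowers this to xla::TorchIndexSelect; for batch_dims == 0 it
  # falls back to the existing XlaGather() path.
  return tf.gather(v, idx, axis=1, batch_dims=1)

print(batched_gather().numpy())  # [[12. 10.], [20. 21.]], shape [2, 2]

The validation in XlaGatherWithBatchDimsOpImpl is visible in this example: when axis is omitted it defaults to batch_dims, and the arguments must satisfy batch_dims <= axis < rank(params).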