diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 84a0e78ff6e..489ffd3fdad 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -25,10 +25,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -152,85 +150,6 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, return Status::OK(); } -Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, - const xla::XlaOp input, - const TensorShape& input_shape, - int batch_dims, xla::XlaOp* gather_output) { - auto indices = context->Input(1); - auto indices_shape = context->InputShape(1); - - absl::optional axis; - if (context->num_inputs() == 3) { - const TensorShape axis_shape = context->InputShape(2); - if (!TensorShapeUtils::IsScalar(axis_shape)) { - return errors::InvalidArgument("axis must be scalar"); - } - DataType axis_type = context->input_type(2); - if (axis_type != DT_INT32 && axis_type != DT_INT64) { - return errors::InvalidArgument("axis must be int32 or int64"); - } - - int64 axis_input; - TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input)); - - const auto params_dims = input_shape.dims(); - if (-params_dims > axis_input || axis_input >= params_dims) { - return errors::InvalidArgument("Expected axis in the range [", - -params_dims, ", ", params_dims, - "), but got ", axis_input); - } - if (axis_input < 0) { - axis_input += params_dims; - } - axis = axis_input; - } - - if (batch_dims != 0) { - if (batch_dims < 0) { - batch_dims = indices_shape.dims() + batch_dims; - } - - axis = axis.value_or(batch_dims); - - if (batch_dims < -indices_shape.dims() || - batch_dims >= indices_shape.dims()) { - return errors::InvalidArgument( - "Expected batch_dims in the range [", -indices_shape.dims(), ", ", - indices_shape.dims(), "), but got ", batch_dims); - } - - if (batch_dims >= input_shape.dims()) { - return errors::InvalidArgument("batch_dims (", batch_dims, - ") must be less than rank(input) (", - input_shape.dims(), ")."); - } - - if (*axis < batch_dims) { - return errors::InvalidArgument("batch_dims (", batch_dims, - ") must be less than or equal to ", - "axis (", *axis, ")."); - } - } - - axis = axis.value_or(0); - DataType index_type = context->input_type(1); - if (index_type != DT_INT32 && index_type != DT_INT64) { - return errors::InvalidArgument("indices must be int32 or int64"); - } - - xla::XlaOp gather; - if (batch_dims > 0) { - *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims); - } else { - // XlaGather() manages degenerate cases, like empty-indices, which are - // error conditions and caught above if batch_dims is not 0. - TF_RETURN_IF_ERROR( - XlaGather(input, input_shape, indices, indices_shape, *axis, - /*indices_are_nd=*/false, context->input_type(0), index_type, - context->builder(), gather_output)); - } - return Status::OK(); -} class GatherOp : public XlaOpKernel { public: explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) { @@ -245,11 +164,76 @@ class GatherOp : public XlaOpKernel { void Compile(XlaOpKernelContext* context) override { auto input = context->Input(0); auto input_shape = context->InputShape(0); + auto indices = context->Input(1); + auto indices_shape = context->InputShape(1); + + absl::optional axis; + if (context->num_inputs() == 3) { + const TensorShape axis_shape = context->InputShape(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), + errors::InvalidArgument("axis must be scalar")); + DataType axis_type = input_type(2); + OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, + errors::InvalidArgument("axis must be int32 or int64")); + + int64 axis_input; + OP_REQUIRES_OK(context, + context->ConstantInputAsIntScalar(2, &axis_input)); + + const auto params_dims = input_shape.dims(); + OP_REQUIRES(context, + -params_dims <= axis_input && axis_input < params_dims, + errors::InvalidArgument("Expected axis in the range [", + -params_dims, ", ", params_dims, + "), but got ", axis_input)); + if (axis_input < 0) { + axis_input += params_dims; + } + axis = axis_input; + } + + if (batch_dims_ != 0) { + if (batch_dims_ < 0) { + batch_dims_ = indices_shape.dims() + batch_dims_; + } + + axis = axis.value_or(batch_dims_); + + OP_REQUIRES(context, + batch_dims_ >= -indices_shape.dims() && + batch_dims_ < indices_shape.dims(), + errors::InvalidArgument("Expected batch_dims in the range [", + -indices_shape.dims(), ", ", + indices_shape.dims(), "), but got ", + batch_dims_)); + + OP_REQUIRES(context, batch_dims_ < input_shape.dims(), + errors::InvalidArgument("batch_dims (", batch_dims_, + ") must be less than rank(input) (", + input_shape.dims(), ").")); + + OP_REQUIRES(context, *axis >= batch_dims_, + errors::InvalidArgument("batch_dims (", batch_dims_, + ") must be less than or equal to ", + "axis (", *axis, ").")); + } + + axis = axis.value_or(0); + DataType index_type = input_type(1); + OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, + errors::InvalidArgument("indices must be int32 or int64")); xla::XlaOp gather; - OP_REQUIRES_OK(context, - XlaGatherWithBatchDimsOpImpl(context, input, input_shape, - batch_dims_, &gather)); + if (batch_dims_ > 0) { + gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_); + } else { + // XlaGather() manages degenerate cases, like empty-indices, which are + // error conditions and caught above if batch_dims is not 0. + OP_REQUIRES_OK( + context, XlaGather(input, input_shape, indices, indices_shape, *axis, + /*indices_are_nd=*/false, input_type(0), + index_type, context->builder(), &gather)); + } context->SetOutput(0, gather); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 7bd25230d46..92346283c31 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -39,13 +39,6 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, DataType index_type, xla::XlaBuilder* builder, xla::XlaOp* gather_output); -// The implementation of Gather and ResourceGather through XLA. Uses `input` as -// the input instead of context->input(0) in order to allow ResourceGather to -// handle obtaining the data from the ResourceVariable. -Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, - const xla::XlaOp input, - const TensorShape& input_shape, - int batch_dims, xla::XlaOp* gather_output); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 60424f85840..7b4125ab76e 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -123,24 +122,27 @@ REGISTER_XLA_OP( class ResourceGatherOp : public XlaOpKernel { public: - explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_dims", &batch_dims_)); - } + explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* builder = ctx->builder(); + DataType type = ctx->expected_output_dtype(0); - TensorShape input_shape; - xla::XlaOp input; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &input_shape, &input)); + TensorShape resource_shape; + xla::XlaOp resource_handle; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, + &resource_handle)); + auto indices = ctx->Input(1); + auto indices_shape = ctx->InputShape(1); + DataType index_type = ctx->input_type(1); xla::XlaOp gather; - OP_REQUIRES_OK(ctx, XlaGatherWithBatchDimsOpImpl(ctx, input, input_shape, - batch_dims_, &gather)); + OP_REQUIRES_OK( + ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, index_type, + builder, &gather)); ctx->SetOutput(0, gather); } - - private: - int32 batch_dims_; }; REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp); diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 18682e03e59..e5b741b8077 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -861,6 +861,7 @@ cuda_py_test( ], # TODO(b/128347673): Re-enable. tags = ["no_windows"], + xla_enable_strict_auto_jit = True, ) tf_py_test( diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 14a4c531ccc..70c6c7ecfbc 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -986,9 +986,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, x = resource_variable_ops.var_handle_op( dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5", container=ops.get_default_graph()._container) - with self.assertRaisesOpError( - "(Resource .*/var5/.* does not exist|Read of uninitialized variable)" - ): + with self.assertRaisesOpError("Resource .*/var5/.* does not exist"): resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval() @test_util.run_deprecated_v1