[TF:XLA] Support batch_dims in ResourceGatherOp.
PiperOrigin-RevId: 266950621
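
For context: `batch_dims` lets a gather treat the leading dimensions of the
parameters and the indices as aligned batch dimensions, so each batch entry
gathers from its own slice instead of from the whole tensor. A minimal sketch
of the user-visible semantics (plain `tf.gather`; illustration only, not part
of this change):

    import tensorflow as tf

    params = tf.constant([[10, 11, 12],
                          [20, 21, 22]])   # shape [2, 3]
    indices = tf.constant([[2, 0],
                           [1, 1]])        # shape [2, 2]

    # batch_dims=1: row i of `indices` gathers from row i of `params`.
    out = tf.gather(params, indices, batch_dims=1)
    # out == [[12, 10], [21, 21]]
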
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -25,8 +25,10 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
 
@@ -150,6 +152,84 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
   return Status::OK();
 }
 
+Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
+                                    const xla::XlaOp input,
+                                    const TensorShape& input_shape,
+                                    int batch_dims, xla::XlaOp* gather_output) {
+  auto indices = context->Input(1);
+  auto indices_shape = context->InputShape(1);
+
+  absl::optional<int64> axis;
+  if (context->num_inputs() == 3) {
+    const TensorShape axis_shape = context->InputShape(2);
+    if (!TensorShapeUtils::IsScalar(axis_shape)) {
+      return errors::InvalidArgument("axis must be scalar");
+    }
+    DataType axis_type = context->input_type(2);
+    if (axis_type != DT_INT32 && axis_type != DT_INT64) {
+      return errors::InvalidArgument("axis must be int32 or int64");
+    }
+
+    int64 axis_input;
+    TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input));
+
+    const auto params_dims = input_shape.dims();
+    if (-params_dims > axis_input || axis_input >= params_dims) {
+      return errors::InvalidArgument("Expected axis in the range [",
+                                     -params_dims, ", ", params_dims,
+                                     "), but got ", axis_input);
+    }
+    if (axis_input < 0) {
+      axis_input += params_dims;
+    }
+    axis = axis_input;
+  }
+
+  if (batch_dims != 0) {
+    if (batch_dims < 0) {
+      batch_dims = indices_shape.dims() + batch_dims;
+    }
+
+    axis = axis.value_or(batch_dims);
+
+    if (batch_dims < -indices_shape.dims() ||
+        batch_dims >= indices_shape.dims()) {
+      return errors::InvalidArgument(
+          "Expected batch_dims in the range [", -indices_shape.dims(), ", ",
+          indices_shape.dims(), "), but got ", batch_dims);
+    }
+
+    if (batch_dims >= input_shape.dims()) {
+      return errors::InvalidArgument("batch_dims (", batch_dims,
+                                     ") must be less than rank(input) (",
+                                     input_shape.dims(), ").");
+    }
+
+    if (*axis < batch_dims) {
+      return errors::InvalidArgument("batch_dims (", batch_dims,
+                                     ") must be less than or equal to ",
+                                     "axis (", *axis, ").");
+    }
+  }
+
+  axis = axis.value_or(0);
+  DataType index_type = context->input_type(1);
+  if (index_type != DT_INT32 && index_type != DT_INT64) {
+    return errors::InvalidArgument("indices must be int32 or int64");
+  }
+
+  if (batch_dims > 0) {
+    *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims);
+  } else {
+    // XlaGather() manages degenerate cases, like empty-indices, which are
+    // error conditions and caught above if batch_dims is not 0.
+    TF_RETURN_IF_ERROR(
+        XlaGather(input, input_shape, indices, indices_shape, *axis,
+                  /*indices_are_nd=*/false, context->input_type(0), index_type,
+                  context->builder(), gather_output));
+  }
+  return Status::OK();
+}
+
 class GatherOp : public XlaOpKernel {
  public:
  explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
@@ -164,76 +245,11 @@ class GatherOp : public XlaOpKernel {
 
   void Compile(XlaOpKernelContext* context) override {
     auto input = context->Input(0);
     auto input_shape = context->InputShape(0);
-    auto indices = context->Input(1);
-    auto indices_shape = context->InputShape(1);
-
-    absl::optional<int64> axis;
-    if (context->num_inputs() == 3) {
-      const TensorShape axis_shape = context->InputShape(2);
-      OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
-                  errors::InvalidArgument("axis must be scalar"));
-      DataType axis_type = input_type(2);
-      OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
-                  errors::InvalidArgument("axis must be int32 or int64"));
-
-      int64 axis_input;
-      OP_REQUIRES_OK(context,
-                     context->ConstantInputAsIntScalar(2, &axis_input));
-
-      const auto params_dims = input_shape.dims();
-      OP_REQUIRES(context,
-                  -params_dims <= axis_input && axis_input < params_dims,
-                  errors::InvalidArgument("Expected axis in the range [",
-                                          -params_dims, ", ", params_dims,
-                                          "), but got ", axis_input));
-      if (axis_input < 0) {
-        axis_input += params_dims;
-      }
-      axis = axis_input;
-    }
-
-    if (batch_dims_ != 0) {
-      if (batch_dims_ < 0) {
-        batch_dims_ = indices_shape.dims() + batch_dims_;
-      }
-
-      axis = axis.value_or(batch_dims_);
-
-      OP_REQUIRES(context,
-                  batch_dims_ >= -indices_shape.dims() &&
-                      batch_dims_ < indices_shape.dims(),
-                  errors::InvalidArgument("Expected batch_dims in the range [",
-                                          -indices_shape.dims(), ", ",
-                                          indices_shape.dims(), "), but got ",
-                                          batch_dims_));
-
-      OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
-                  errors::InvalidArgument("batch_dims (", batch_dims_,
-                                          ") must be less than rank(input) (",
-                                          input_shape.dims(), ")."));
-
-      OP_REQUIRES(context, *axis >= batch_dims_,
-                  errors::InvalidArgument("batch_dims (", batch_dims_,
-                                          ") must be less than or equal to ",
-                                          "axis (", *axis, ")."));
-    }
-
-    axis = axis.value_or(0);
-    DataType index_type = input_type(1);
-    OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
-                errors::InvalidArgument("indices must be int32 or int64"));
-
     xla::XlaOp gather;
-    if (batch_dims_ > 0) {
-      gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
-    } else {
-      // XlaGather() manages degenerate cases, like empty-indices, which are
-      // error conditions and caught above if batch_dims is not 0.
-      OP_REQUIRES_OK(
-          context, XlaGather(input, input_shape, indices, indices_shape, *axis,
-                             /*indices_are_nd=*/false, input_type(0),
-                             index_type, context->builder(), &gather));
-    }
+    OP_REQUIRES_OK(context,
+                   XlaGatherWithBatchDimsOpImpl(context, input, input_shape,
+                                                batch_dims_, &gather));
     context->SetOutput(0, gather);
   }
 
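
The new helper normalizes negative `axis` and `batch_dims` values and checks
their ordering before dispatching to xla::TorchIndexSelect (batched) or
XlaGather (unbatched). A Python sketch of the same validation, useful for
reasoning about the error cases (`check_gather_args` is a hypothetical name,
not a real API):

    def check_gather_args(params_rank, indices_rank, axis=None, batch_dims=0):
      # Mirrors the checks in XlaGatherWithBatchDimsOpImpl above.
      if axis is not None:
        if not -params_rank <= axis < params_rank:
          raise ValueError("Expected axis in the range [%d, %d), but got %d"
                           % (-params_rank, params_rank, axis))
        if axis < 0:
          axis += params_rank
      if batch_dims != 0:
        if batch_dims < 0:
          batch_dims += indices_rank
        if axis is None:
          axis = batch_dims
        if not -indices_rank <= batch_dims < indices_rank:
          raise ValueError("batch_dims out of range")
        if batch_dims >= params_rank:
          raise ValueError("batch_dims must be less than rank(params)")
        if axis < batch_dims:
          raise ValueError("batch_dims must be less than or equal to axis")
      return (axis if axis is not None else 0), batch_dims
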
--- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
@@ -39,6 +39,13 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
                  DataType index_type, xla::XlaBuilder* builder,
                  xla::XlaOp* gather_output);
 
+// The implementation of Gather and ResourceGather through XLA. Uses `input` as
+// the input instead of context->input(0) in order to allow ResourceGather to
+// handle obtaining the data from the ResourceVariable.
+Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
+                                    const xla::XlaOp input,
+                                    const TensorShape& input_shape,
+                                    int batch_dims, xla::XlaOp* gather_output);
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
@@ -122,27 +123,24 @@ REGISTER_XLA_OP(
 
 class ResourceGatherOp : public XlaOpKernel {
  public:
-  explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_dims", &batch_dims_));
+  }
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::XlaBuilder* builder = ctx->builder();
-
     DataType type = ctx->expected_output_dtype(0);
 
-    TensorShape resource_shape;
-    xla::XlaOp resource_handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
-                                               &resource_handle));
+    TensorShape input_shape;
+    xla::XlaOp input;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &input_shape, &input));
 
-    auto indices = ctx->Input(1);
-    auto indices_shape = ctx->InputShape(1);
-    DataType index_type = ctx->input_type(1);
     xla::XlaOp gather;
-    OP_REQUIRES_OK(
-        ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
-                       /*axis=*/0, /*indices_are_nd=*/false, type, index_type,
-                       builder, &gather));
+    OP_REQUIRES_OK(ctx, XlaGatherWithBatchDimsOpImpl(ctx, input, input_shape,
+                                                     batch_dims_, &gather));
     ctx->SetOutput(0, gather);
   }
 
+ private:
+  int32 batch_dims_;
 };
 REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp);
 
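
With the `batch_dims` attribute read in the constructor and forwarded to the
shared helper, batched gathers from resource variables can now compile under
XLA. A minimal sketch using the raw op so that the ResourceGather kernel
(rather than a plain Gather) is exercised; `jit_compile` (named
`experimental_compile` in older releases) forces the XLA path:

    import tensorflow as tf

    v = tf.Variable([[10, 11, 12],
                     [20, 21, 22]])         # resource variable, shape [2, 3]
    indices = tf.constant([[2], [0]])       # shape [2, 1]

    @tf.function(jit_compile=True)          # compile through the tf2xla bridge
    def batched_lookup():
      # batch_dims=1: batch entry i of `indices` selects from row i of `v`.
      return tf.raw_ops.ResourceGather(
          resource=v.handle, indices=indices, dtype=v.dtype, batch_dims=1)

    print(batched_lookup())                 # [[12], [20]]
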
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -861,7 +861,6 @@ cuda_py_test(
     ],
     # TODO(b/128347673): Re-enable.
     tags = ["no_windows"],
-    xla_enable_strict_auto_jit = True,
 )
 
 tf_py_test(
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -986,7 +986,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
     x = resource_variable_ops.var_handle_op(
         dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
         container=ops.get_default_graph()._container)
-    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
+    with self.assertRaisesOpError(
+        "(Resource .*/var5/.* does not exist|Read of uninitialized variable)"
+    ):
       resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
 
   @test_util.run_deprecated_v1