parent 17099181cd
commit bd76d3a802
@@ -25,10 +25,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

@@ -152,85 +150,6 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
  return Status::OK();
}

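// See the comment in gather_op_helpers.h: takes `input` explicitly rather
// than reading context->Input(0), so that ResourceGather can pass in the
// value read from its ResourceVariable.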
Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
                                    const xla::XlaOp input,
                                    const TensorShape& input_shape,
                                    int batch_dims, xla::XlaOp* gather_output) {
  auto indices = context->Input(1);
  auto indices_shape = context->InputShape(1);

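  // A third input means an explicit `axis` was supplied; validate it and
  // normalize negative values. Otherwise leave `axis` unset so the
  // batch_dims handling below can pick a default.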
  absl::optional<int64> axis;
  if (context->num_inputs() == 3) {
    const TensorShape axis_shape = context->InputShape(2);
    if (!TensorShapeUtils::IsScalar(axis_shape)) {
      return errors::InvalidArgument("axis must be scalar");
    }
    DataType axis_type = context->input_type(2);
    if (axis_type != DT_INT32 && axis_type != DT_INT64) {
      return errors::InvalidArgument("axis must be int32 or int64");
    }

    int64 axis_input;
    TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input));

    const auto params_dims = input_shape.dims();
    if (-params_dims > axis_input || axis_input >= params_dims) {
      return errors::InvalidArgument("Expected axis in the range [",
                                     -params_dims, ", ", params_dims,
                                     "), but got ", axis_input);
    }
    if (axis_input < 0) {
      axis_input += params_dims;
    }
    axis = axis_input;
  }

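  // Negative batch_dims counts back from the rank of `indices`; when no axis
  // input was provided, axis defaults to batch_dims. Both are then
  // range-checked, and axis may not precede the batch dimensions.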
  if (batch_dims != 0) {
    if (batch_dims < 0) {
      batch_dims = indices_shape.dims() + batch_dims;
    }

    axis = axis.value_or(batch_dims);

    if (batch_dims < -indices_shape.dims() ||
        batch_dims >= indices_shape.dims()) {
      return errors::InvalidArgument(
          "Expected batch_dims in the range [", -indices_shape.dims(), ", ",
          indices_shape.dims(), "), but got ", batch_dims);
    }

    if (batch_dims >= input_shape.dims()) {
      return errors::InvalidArgument("batch_dims (", batch_dims,
                                     ") must be less than rank(input) (",
                                     input_shape.dims(), ").");
    }

    if (*axis < batch_dims) {
      return errors::InvalidArgument("batch_dims (", batch_dims,
                                     ") must be less than or equal to ",
                                     "axis (", *axis, ").");
    }
  }

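  // Default to gathering along dimension 0 when no axis was established
  // above.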
  axis = axis.value_or(0);
  DataType index_type = context->input_type(1);
  if (index_type != DT_INT32 && index_type != DT_INT64) {
    return errors::InvalidArgument("indices must be int32 or int64");
  }

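  // batch_dims > 0 lowers to a batched TorchIndexSelect; the plain case goes
  // through XlaGather, which also handles degenerate shapes.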
  xla::XlaOp gather;
  if (batch_dims > 0) {
    *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims);
  } else {
    // XlaGather() manages degenerate cases, like empty-indices, which are
    // error conditions and caught above if batch_dims is not 0.
    TF_RETURN_IF_ERROR(
        XlaGather(input, input_shape, indices, indices_shape, *axis,
                  /*indices_are_nd=*/false, context->input_type(0), index_type,
                  context->builder(), gather_output));
  }
  return Status::OK();
}

class GatherOp : public XlaOpKernel {
 public:
  explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
@@ -245,11 +164,76 @@ class GatherOp : public XlaOpKernel {
  void Compile(XlaOpKernelContext* context) override {
    auto input = context->Input(0);
    auto input_shape = context->InputShape(0);
    auto indices = context->Input(1);
    auto indices_shape = context->InputShape(1);

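    // As in XlaGatherWithBatchDimsOpImpl above: resolve the optional scalar
    // `axis` operand when Gather is invoked with three inputs.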
    absl::optional<int64> axis;
    if (context->num_inputs() == 3) {
      const TensorShape axis_shape = context->InputShape(2);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
                  errors::InvalidArgument("axis must be scalar"));
      DataType axis_type = input_type(2);
      OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
                  errors::InvalidArgument("axis must be int32 or int64"));

      int64 axis_input;
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntScalar(2, &axis_input));

      const auto params_dims = input_shape.dims();
      OP_REQUIRES(context,
                  -params_dims <= axis_input && axis_input < params_dims,
                  errors::InvalidArgument("Expected axis in the range [",
                                          -params_dims, ", ", params_dims,
                                          "), but got ", axis_input));
      if (axis_input < 0) {
        axis_input += params_dims;
      }
      axis = axis_input;
    }

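    // Mirrors the helper's checks: normalize negative batch_dims, default
    // axis to batch_dims, then range-check both.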
    if (batch_dims_ != 0) {
      if (batch_dims_ < 0) {
        batch_dims_ = indices_shape.dims() + batch_dims_;
      }

      axis = axis.value_or(batch_dims_);

      OP_REQUIRES(context,
                  batch_dims_ >= -indices_shape.dims() &&
                      batch_dims_ < indices_shape.dims(),
                  errors::InvalidArgument("Expected batch_dims in the range [",
                                          -indices_shape.dims(), ", ",
                                          indices_shape.dims(), "), but got ",
                                          batch_dims_));

      OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
                  errors::InvalidArgument("batch_dims (", batch_dims_,
                                          ") must be less than rank(input) (",
                                          input_shape.dims(), ")."));

      OP_REQUIRES(context, *axis >= batch_dims_,
                  errors::InvalidArgument("batch_dims (", batch_dims_,
                                          ") must be less than or equal to ",
                                          "axis (", *axis, ")."));
    }

    axis = axis.value_or(0);
    DataType index_type = input_type(1);
    OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
                errors::InvalidArgument("indices must be int32 or int64"));

    xla::XlaOp gather;
    OP_REQUIRES_OK(context,
                   XlaGatherWithBatchDimsOpImpl(context, input, input_shape,
                                                batch_dims_, &gather));
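    // Same lowering choice as the helper: TorchIndexSelect for batched
    // gathers, XlaGather otherwise.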
    if (batch_dims_ > 0) {
      gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
    } else {
      // XlaGather() manages degenerate cases, like empty-indices, which are
      // error conditions and caught above if batch_dims is not 0.
      OP_REQUIRES_OK(
          context, XlaGather(input, input_shape, indices, indices_shape, *axis,
                             /*indices_are_nd=*/false, input_type(0),
                             index_type, context->builder(), &gather));
    }
    context->SetOutput(0, gather);
  }

@@ -39,13 +39,6 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
                 DataType index_type, xla::XlaBuilder* builder,
                 xla::XlaOp* gather_output);

// The implementation of Gather and ResourceGather through XLA. Uses `input` as
// the input instead of context->input(0) in order to allow ResourceGather to
// handle obtaining the data from the ResourceVariable.
Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
                                    const xla::XlaOp input,
                                    const TensorShape& input_shape,
                                    int batch_dims, xla::XlaOp* gather_output);
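// Sketch of a typical call site (variable names are illustrative):
//   xla::XlaOp gather;
//   OP_REQUIRES_OK(ctx, XlaGatherWithBatchDimsOpImpl(ctx, input, input_shape,
//                                                    batch_dims, &gather));
//   ctx->SetOutput(0, gather);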
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_

@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -123,24 +122,27 @@ REGISTER_XLA_OP(

class ResourceGatherOp : public XlaOpKernel {
 public:
  explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_dims", &batch_dims_));
  }
  explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    DataType type = ctx->expected_output_dtype(0);

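    // Read the variable's current value (shape plus XLA handle) from
    // resource input 0.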
    TensorShape input_shape;
    xla::XlaOp input;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &input_shape, &input));
    TensorShape resource_shape;
    xla::XlaOp resource_handle;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
                                               &resource_handle));

    auto indices = ctx->Input(1);
    auto indices_shape = ctx->InputShape(1);
    DataType index_type = ctx->input_type(1);
    xla::XlaOp gather;
    OP_REQUIRES_OK(ctx, XlaGatherWithBatchDimsOpImpl(ctx, input, input_shape,
                                                     batch_dims_, &gather));
    OP_REQUIRES_OK(
        ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
                       /*axis=*/0, /*indices_are_nd=*/false, type, index_type,
                       builder, &gather));
    ctx->SetOutput(0, gather);
  }

 private:
  int32 batch_dims_;
};
REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp);

@@ -861,6 +861,7 @@ cuda_py_test(
    ],
    # TODO(b/128347673): Re-enable.
    tags = ["no_windows"],
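    # Assumed semantics of this attribute: also builds a variant of the test
    # that runs with XLA strict auto-jit enabled.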
    xla_enable_strict_auto_jit = True,
)

tf_py_test(
@@ -986,9 +986,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
    x = resource_variable_ops.var_handle_op(
        dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
        container=ops.get_default_graph()._container)
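    # Reading through a handle whose variable was never initialized in this
    # container should fail.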
    with self.assertRaisesOpError(
        "(Resource .*/var5/.* does not exist|Read of uninitialized variable)"
    ):
    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
      resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()

  @test_util.run_deprecated_v1