Automated rollback of commit ce2b635fcd

PiperOrigin-RevId: 266990283
This commit is contained in:
Sanjoy Das 2019-09-03 12:40:27 -07:00 committed by TensorFlower Gardener
parent 17099181cd
commit bd76d3a802
5 changed files with 84 additions and 106 deletions

View File

@ -25,10 +25,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@ -152,85 +150,6 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
return Status::OK();
}
Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
const xla::XlaOp input,
const TensorShape& input_shape,
int batch_dims, xla::XlaOp* gather_output) {
auto indices = context->Input(1);
auto indices_shape = context->InputShape(1);
absl::optional<int64> axis;
if (context->num_inputs() == 3) {
const TensorShape axis_shape = context->InputShape(2);
if (!TensorShapeUtils::IsScalar(axis_shape)) {
return errors::InvalidArgument("axis must be scalar");
}
DataType axis_type = context->input_type(2);
if (axis_type != DT_INT32 && axis_type != DT_INT64) {
return errors::InvalidArgument("axis must be int32 or int64");
}
int64 axis_input;
TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input));
const auto params_dims = input_shape.dims();
if (-params_dims > axis_input || axis_input >= params_dims) {
return errors::InvalidArgument("Expected axis in the range [",
-params_dims, ", ", params_dims,
"), but got ", axis_input);
}
if (axis_input < 0) {
axis_input += params_dims;
}
axis = axis_input;
}
if (batch_dims != 0) {
if (batch_dims < 0) {
batch_dims = indices_shape.dims() + batch_dims;
}
axis = axis.value_or(batch_dims);
if (batch_dims < -indices_shape.dims() ||
batch_dims >= indices_shape.dims()) {
return errors::InvalidArgument(
"Expected batch_dims in the range [", -indices_shape.dims(), ", ",
indices_shape.dims(), "), but got ", batch_dims);
}
if (batch_dims >= input_shape.dims()) {
return errors::InvalidArgument("batch_dims (", batch_dims,
") must be less than rank(input) (",
input_shape.dims(), ").");
}
if (*axis < batch_dims) {
return errors::InvalidArgument("batch_dims (", batch_dims,
") must be less than or equal to ",
"axis (", *axis, ").");
}
}
axis = axis.value_or(0);
DataType index_type = context->input_type(1);
if (index_type != DT_INT32 && index_type != DT_INT64) {
return errors::InvalidArgument("indices must be int32 or int64");
}
xla::XlaOp gather;
if (batch_dims > 0) {
*gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims);
} else {
// XlaGather() manages degenerate cases, like empty-indices, which are
// error conditions and caught above if batch_dims is not 0.
TF_RETURN_IF_ERROR(
XlaGather(input, input_shape, indices, indices_shape, *axis,
/*indices_are_nd=*/false, context->input_type(0), index_type,
context->builder(), gather_output));
}
return Status::OK();
}
class GatherOp : public XlaOpKernel {
public:
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
@ -245,11 +164,76 @@ class GatherOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* context) override {
auto input = context->Input(0);
auto input_shape = context->InputShape(0);
auto indices = context->Input(1);
auto indices_shape = context->InputShape(1);
absl::optional<int64> axis;
if (context->num_inputs() == 3) {
const TensorShape axis_shape = context->InputShape(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
errors::InvalidArgument("axis must be scalar"));
DataType axis_type = input_type(2);
OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
errors::InvalidArgument("axis must be int32 or int64"));
int64 axis_input;
OP_REQUIRES_OK(context,
context->ConstantInputAsIntScalar(2, &axis_input));
const auto params_dims = input_shape.dims();
OP_REQUIRES(context,
-params_dims <= axis_input && axis_input < params_dims,
errors::InvalidArgument("Expected axis in the range [",
-params_dims, ", ", params_dims,
"), but got ", axis_input));
if (axis_input < 0) {
axis_input += params_dims;
}
axis = axis_input;
}
if (batch_dims_ != 0) {
if (batch_dims_ < 0) {
batch_dims_ = indices_shape.dims() + batch_dims_;
}
axis = axis.value_or(batch_dims_);
OP_REQUIRES(context,
batch_dims_ >= -indices_shape.dims() &&
batch_dims_ < indices_shape.dims(),
errors::InvalidArgument("Expected batch_dims in the range [",
-indices_shape.dims(), ", ",
indices_shape.dims(), "), but got ",
batch_dims_));
OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than rank(input) (",
input_shape.dims(), ")."));
OP_REQUIRES(context, *axis >= batch_dims_,
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than or equal to ",
"axis (", *axis, ")."));
}
axis = axis.value_or(0);
DataType index_type = input_type(1);
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("indices must be int32 or int64"));
xla::XlaOp gather;
OP_REQUIRES_OK(context,
XlaGatherWithBatchDimsOpImpl(context, input, input_shape,
batch_dims_, &gather));
if (batch_dims_ > 0) {
gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
} else {
// XlaGather() manages degenerate cases, like empty-indices, which are
// error conditions and caught above if batch_dims is not 0.
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, *axis,
/*indices_are_nd=*/false, input_type(0),
index_type, context->builder(), &gather));
}
context->SetOutput(0, gather);
}

View File

@ -39,13 +39,6 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
DataType index_type, xla::XlaBuilder* builder,
xla::XlaOp* gather_output);
// The implementation of Gather and ResourceGather through XLA. Uses `input` as
// the input instead of context->input(0) in order to allow ResourceGather to
// handle obtaining the data from the ResourceVariable.
Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
const xla::XlaOp input,
const TensorShape& input_shape,
int batch_dims, xla::XlaOp* gather_output);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@ -123,24 +122,27 @@ REGISTER_XLA_OP(
class ResourceGatherOp : public XlaOpKernel {
public:
explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_dims", &batch_dims_));
}
explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* builder = ctx->builder();
DataType type = ctx->expected_output_dtype(0);
TensorShape input_shape;
xla::XlaOp input;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &input_shape, &input));
TensorShape resource_shape;
xla::XlaOp resource_handle;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
&resource_handle));
auto indices = ctx->Input(1);
auto indices_shape = ctx->InputShape(1);
DataType index_type = ctx->input_type(1);
xla::XlaOp gather;
OP_REQUIRES_OK(ctx, XlaGatherWithBatchDimsOpImpl(ctx, input, input_shape,
batch_dims_, &gather));
OP_REQUIRES_OK(
ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
/*axis=*/0, /*indices_are_nd=*/false, type, index_type,
builder, &gather));
ctx->SetOutput(0, gather);
}
private:
int32 batch_dims_;
};
REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp);

View File

@ -861,6 +861,7 @@ cuda_py_test(
],
# TODO(b/128347673): Re-enable.
tags = ["no_windows"],
xla_enable_strict_auto_jit = True,
)
tf_py_test(

View File

@ -986,9 +986,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
x = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
container=ops.get_default_graph()._container)
with self.assertRaisesOpError(
"(Resource .*/var5/.* does not exist|Read of uninitialized variable)"
):
with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
@test_util.run_deprecated_v1