308 lines
11 KiB
C++
308 lines
11 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
|
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
|
|
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
|
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
|
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
|
#include "tensorflow/compiler/xla/literal.h"
|
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
|
#include "tensorflow/core/framework/types.h"
|
|
|
|
namespace tensorflow {
|
|
namespace {
|
|
|
|
class VarIsInitializedOp : public XlaOpKernel {
|
|
public:
|
|
explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
XlaResource* variable;
|
|
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
|
|
ctx->SetOutput(
|
|
0, xla::ConstantR0<bool>(ctx->builder(), variable->initialized()));
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
|
|
|
|
class VariableShapeOp : public XlaOpKernel {
|
|
public:
|
|
explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
|
|
}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
DataType variable_dtype;
|
|
TensorShape shape;
|
|
OP_REQUIRES_OK(ctx,
|
|
ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape));
|
|
Tensor shape_constant(out_dtype_, TensorShape({shape.dims()}));
|
|
OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant));
|
|
ctx->SetConstantOutput(0, shape_constant);
|
|
}
|
|
|
|
private:
|
|
DataType out_dtype_;
|
|
};
|
|
REGISTER_XLA_OP(Name("VariableShape").IsMetadataOp(), VariableShapeOp);
|
|
|
|
class ReadVariableOp : public XlaOpKernel {
|
|
public:
|
|
explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
|
}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
xla::XlaOp handle;
|
|
OP_REQUIRES_OK(
|
|
ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
|
|
ctx->SetOutput(0, handle);
|
|
}
|
|
|
|
private:
|
|
DataType dtype_;
|
|
};
|
|
REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp);
|
|
|
|
class AssignVariableOp : public XlaOpKernel {
|
|
public:
|
|
explicit AssignVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
OP_REQUIRES_OK(ctx,
|
|
ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1)));
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp);
|
|
|
|
class AssignAddVariableOp : public XlaOpKernel {
|
|
public:
|
|
explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
DataType type = ctx->input_type(1);
|
|
xla::XlaOp handle;
|
|
OP_REQUIRES_OK(ctx,
|
|
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
|
|
handle = xla::Add(handle, ctx->Input(1));
|
|
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(
|
|
Name("AssignAddVariableOp").TypeConstraint("dtype", kNumericTypes),
|
|
AssignAddVariableOp);
|
|
|
|
class AssignSubVariableOp : public XlaOpKernel {
|
|
public:
|
|
explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
DataType type = ctx->input_type(1);
|
|
xla::XlaOp handle;
|
|
OP_REQUIRES_OK(ctx,
|
|
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
|
|
handle = xla::Sub(handle, ctx->Input(1));
|
|
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(
|
|
Name("AssignSubVariableOp").TypeConstraint("dtype", kNumericTypes),
|
|
AssignSubVariableOp);
|
|
|
|
class ResourceGatherOp : public XlaOpKernel {
|
|
public:
|
|
explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_dims", &batch_dims_));
|
|
}
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
DataType type = ctx->expected_output_dtype(0);
|
|
|
|
TensorShape input_shape;
|
|
xla::XlaOp input;
|
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &input_shape, &input));
|
|
|
|
xla::XlaOp gather;
|
|
OP_REQUIRES_OK(ctx, XlaGatherWithBatchDimsOpImpl(ctx, input, input_shape,
|
|
batch_dims_, &gather));
|
|
ctx->SetOutput(0, gather);
|
|
}
|
|
|
|
private:
|
|
int32 batch_dims_;
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp);
|
|
|
|
class ResourceScatterOp : public XlaOpKernel {
|
|
public:
|
|
explicit ResourceScatterOp(
|
|
OpKernelConstruction* context, bool indices_are_vectors,
|
|
std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&,
|
|
xla::XlaBuilder*)>
|
|
combiner)
|
|
: XlaOpKernel(context),
|
|
indices_are_vectors_(indices_are_vectors),
|
|
combiner_(std::move(combiner)) {}
|
|
|
|
void Compile(XlaOpKernelContext* context) override {
|
|
xla::XlaBuilder* builder = context->builder();
|
|
|
|
DataType dtype = context->input_type(2);
|
|
TensorShape var_shape;
|
|
xla::XlaOp var_value;
|
|
OP_REQUIRES_OK(
|
|
context, context->ReadVariableInput(0, dtype, &var_shape, &var_value));
|
|
|
|
const xla::XlaOp indices = context->Input(1);
|
|
const xla::XlaOp updates = context->Input(2);
|
|
|
|
auto result = XlaScatter(var_value, updates, indices, indices_are_vectors_,
|
|
combiner_, builder);
|
|
OP_REQUIRES_OK(context, result.status());
|
|
OP_REQUIRES_OK(context,
|
|
context->AssignVariable(0, dtype, result.ValueOrDie()));
|
|
}
|
|
|
|
private:
|
|
const bool indices_are_vectors_;
|
|
const std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&,
|
|
xla::XlaBuilder*)>
|
|
combiner_;
|
|
};
|
|
|
|
class ResourceScatterAddOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterAddOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
|
|
|
|
private:
|
|
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
|
|
xla::XlaBuilder* builder) {
|
|
return xla::Add(x, y);
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp);
|
|
|
|
class ResourceScatterSubOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterSubOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
|
|
|
|
private:
|
|
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
|
|
xla::XlaBuilder* builder) {
|
|
return xla::Sub(x, y);
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp);
|
|
|
|
class ResourceScatterMulOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterMulOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
|
|
|
|
private:
|
|
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
|
|
xla::XlaBuilder* builder) {
|
|
return xla::Mul(x, y);
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp);
|
|
|
|
class ResourceScatterDivOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterDivOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
|
|
|
|
private:
|
|
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
|
|
xla::XlaBuilder* builder) {
|
|
return xla::Div(x, y);
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp);
|
|
|
|
class ResourceScatterMinOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterMinOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
|
|
|
|
private:
|
|
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
|
|
xla::XlaBuilder* builder) {
|
|
return xla::Min(x, y);
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp);
|
|
|
|
class ResourceScatterMaxOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterMaxOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
|
|
|
|
private:
|
|
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
|
|
xla::XlaBuilder* builder) {
|
|
return xla::Max(x, y);
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp);
|
|
|
|
class ResourceScatterUpdateOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterUpdateOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/false,
|
|
/*combiner=*/{}) {}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterUpdate"), ResourceScatterUpdateOp);
|
|
|
|
class ResourceScatterNdUpdateOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterNdUpdateOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/true,
|
|
/*combiner=*/{}) {}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterNdUpdate"), ResourceScatterNdUpdateOp);
|
|
|
|
class ResourceScatterNdAddOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterNdAddOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/true,
|
|
/*combiner=*/Combine) {}
|
|
|
|
private:
|
|
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
|
|
xla::XlaBuilder* builder) {
|
|
return xla::Add(x, y);
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp);
|
|
|
|
class ResourceScatterNdSubOp : public ResourceScatterOp {
|
|
public:
|
|
explicit ResourceScatterNdSubOp(OpKernelConstruction* context)
|
|
: ResourceScatterOp(context, /*indices_are_vectors=*/true,
|
|
/*combiner=*/Combine) {}
|
|
|
|
private:
|
|
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
|
|
xla::XlaBuilder* builder) {
|
|
return xla::Sub(x, y);
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ResourceScatterNdSub"), ResourceScatterNdSubOp);
|
|
|
|
} // namespace
|
|
} // namespace tensorflow
|