Added forked stateless versions of the If and While ops. They should only be used
when the then/else branch functions of If, or the cond/body functions of While,
contain no stateful ops. They are lowered to the same XLA ops as If and While. One
use case is in the S4TF compiler: https://github.com/apple/swift/pull/18509

PiperOrigin-RevId: 207977126
This commit is contained in: parent f2ff7b1607, commit 8453e23a8b
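The "no stateful ops" precondition is left to the caller. As a rough, hedged illustration (not part of this commit, and not necessarily how the S4TF pass decides), a graph-building frontend could scan a candidate branch/body FunctionDef for stateful ops before choosing the stateless variant; the helper name below is an assumption:

// Hypothetical helper (illustrative only): returns true when every node in
// `fdef` resolves to an op whose OpDef is not marked stateful, i.e. when the
// function is a safe then/else or cond/body candidate for StatelessIf/While.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op.h"

static bool HasOnlyStatelessOps(const tensorflow::FunctionDef& fdef) {
  for (const tensorflow::NodeDef& node : fdef.node_def()) {
    const tensorflow::OpRegistrationData* op_reg_data = nullptr;
    if (!tensorflow::OpRegistry::Global()
             ->LookUp(node.op(), &op_reg_data)
             .ok()) {
      return false;  // Unknown op: conservatively treat it as stateful.
    }
    if (op_reg_data->op_def.is_stateful()) return false;
  }
  return true;
}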
@@ -1619,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
  TF_DeleteFunction(func1);
}

// This test only works when the TF build includes XLA compiler. One way to set
// this up is via bazel build option "--define with_xla_support=true".
//
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
// something like TENSORFLOW_CAPI_USE_XLA.
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
  TF_Function* func;
  const std::string funcName = "BranchFunc";
  DefineFunction(funcName.c_str(), &func);
  TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* feed = Placeholder(host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_OperationDescription* desc =
      TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
  TF_AddInput(desc, {true_cond, 0});
  TF_Output inputs[] = {{feed, 0}};
  TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_SetAttrType(desc, "Tcond", TF_BOOL);
  TF_DataType inputType = TF_INT32;
  TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
  TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
  TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
  TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
  TF_SetDevice(desc, "/device:XLA_CPU:0");
  auto op = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_NE(op, nullptr);

  // Create a session for this graph.
  CSession csession(host_graph_, s_, /*use_XLA*/ true);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Run the graph.
  csession.SetInputs({{feed, Int32Tensor(17)}});
  csession.SetOutputs({op});
  csession.Run(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Tensor* out = csession.output_tensor(0);
  ASSERT_TRUE(out != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(out));
  EXPECT_EQ(0, TF_NumDims(out));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
  int32* output_contents = static_cast<int32*>(TF_TensorData(out));
  EXPECT_EQ(-17, *output_contents);

  // Clean up
  csession.CloseAndDelete(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_DeleteFunction(func);
}
#endif  // TENSORFLOW_EAGER_USE_XLA

}  // namespace
}  // namespace tensorflow
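The new StatelessWhile op is registered below but not exercised by this test. A hedged sketch of wiring one up through the same C API calls used above; `graph`, `s`, `input_op`, and the function names "LoopCond"/"LoopBody" are placeholders, both functions would first need to be added to the graph with TF_GraphCopyFunction, and the attr names follow the StatelessWhile op registration later in this diff:

// Illustrative fragment only; not part of this commit.
TF_OperationDescription* desc =
    TF_NewOperation(graph, "StatelessWhile", "WhileNode");
TF_Output inputs[] = {{input_op, 0}};
TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
TF_DataType input_type = TF_INT32;
TF_SetAttrTypeList(desc, "T", &input_type, 1);  // loop-variable dtypes
TF_SetAttrFuncName(desc, "cond", "LoopCond", strlen("LoopCond"));
TF_SetAttrFuncName(desc, "body", "LoopBody", strlen("LoopBody"));
TF_Operation* while_op = TF_FinishOperation(desc, s);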
@@ -26,6 +26,10 @@ limitations under the License.
using tensorflow::GraphDef;
using tensorflow::NodeDef;

static void BoolDeallocator(void* data, size_t, void* arg) {
  delete[] static_cast<bool*>(data);
}

static void Int32Deallocator(void* data, size_t, void* arg) {
  delete[] static_cast<int32_t*>(data);
}
@@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) {
  delete[] static_cast<float*>(data);
}

TF_Tensor* BoolTensor(bool v) {
  const int num_bytes = sizeof(bool);
  bool* values = new bool[1];
  values[0] = v;
  return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
                      nullptr);
}

TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
  int64_t num_values = 1;
  for (int i = 0; i < num_dims; ++i) {
@@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
  return op;
}

TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
                          const char* name) {
  unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
  return Const(tensor.get(), graph, s, name);
}

TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
                          const char* name) {
  unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
@@ -31,6 +31,8 @@ using ::tensorflow::string;
typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
    unique_tensor_ptr;

TF_Tensor* BoolTensor(bool v);

// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);

@@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
                    const char* name = "const");

TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
                          const char* name = "scalar");

TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
                          const char* name = "scalar");
@@ -247,6 +247,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}

REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp);
REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp);
REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp);

}  // namespace tensorflow
@@ -301,6 +301,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
}

REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);

}  // namespace tensorflow
tensorflow/core/api_def/base_api/api_def_StatelessIf.pbtxt (new file, 43 lines)
@@ -0,0 +1,43 @@
op {
  graph_op_name: "StatelessIf"
  in_arg {
    name: "cond"
    description: <<END
The predicate. A Tensor. If the tensor is a scalar of non-boolean type,
the scalar is converted to a boolean according to the following rule:
if the scalar is a numerical value, non-zero means `True` and zero means
`False`; if the scalar is a string, non-empty means `True` and empty
means `False`. If the tensor is not a scalar, being empty means `False`
and being non-empty means `True`.

This should only be used when the if then/else body functions do not
have stateful ops.
END
  }
  in_arg {
    name: "input"
    description: "A list of input tensors."
  }
  out_arg {
    name: "output"
    description: "A list of return values."
  }
  attr { name: "Tin" description: "A list of input types." }
  attr { name: "Tout" description: "A list of output types." }
  attr {
    name: "then_branch"
    description: <<END
A function that takes 'inputs' and returns a list of tensors, whose
types are the same as what else_branch returns.
END
  }
  attr {
    name: "else_branch"
    description: <<END
A function that takes 'inputs' and returns a list of tensors, whose
types are the same as what then_branch returns.
END
  }
  summary: "output = cond ? then_branch(input) : else_branch(input)"
}
@@ -0,0 +1,36 @@
op {
  graph_op_name: "StatelessWhile"
  in_arg {
    name: "input"
    description: "A list of input tensors whose types are T."
  }
  out_arg {
    name: "output"
    description: "A list of output tensors whose types are T."
  }
  attr { name: "T" description: "dtype in use." }
  attr {
    name: "cond"
    description: <<END
A function that takes 'input' and returns a tensor. If the tensor is
a scalar of non-boolean type, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means `True` and zero means `False`; if the scalar is
a string, non-empty means `True` and empty means `False`. If the
tensor is not a scalar, being non-empty means `True` and being empty
means `False`.

This should only be used when the while condition and body functions
do not have stateful ops.
END
  }
  attr {
    name: "body"
    description: <<END
A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified
by T.
END
  }
  summary: "output = input; While (Cond(output)) { output = Body(output) }"
}
@@ -0,0 +1 @@
op { graph_op_name: "StatelessIf" visibility: HIDDEN }
@@ -0,0 +1 @@
op { graph_op_name: "StatelessWhile" visibility: HIDDEN }
@@ -218,6 +218,10 @@ REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp);

class WhileOp : public AsyncOpKernel {
 public:
  explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
@@ -379,6 +383,9 @@ REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp);

REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp);

Status GetScalar(OpKernelContext* ctx, int index, int32* value,
                 const char* label) {
  Tensor t = ctx->input(index);
@@ -90,6 +90,17 @@ else_branch: A function that takes 'inputs' and returns a list of
tensors, whose types are the same as what then_branch returns.
)doc");

REGISTER_OP("StatelessIf")
    .Input("cond: Tcond")
    .Input("input: Tin")
    .Output("output: Tout")
    .Attr("Tcond: type")
    .Attr("Tin: list(type) >= 0")
    .Attr("Tout: list(type) >= 0")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("If")
    .Input("cond: Tcond")
    .Input("input: Tin")
@@ -133,8 +144,6 @@ body: A function that takes a list of tensors and returns another
by T.
)doc");

// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("While")
    .Input("input: T")
    .Output("output: T")
@@ -149,6 +158,19 @@ REGISTER_OP("While")
      return Status::OK();
    });

REGISTER_OP("StatelessWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      for (int i = 0; i < c->num_outputs(); ++i) {
        c->set_output(i, c->input(i));
      }
      return Status::OK();
    });

REGISTER_OP("For")
    .Input("start: int32")
    .Input("limit: int32")