Added forked stateless versions of the If and While ops. They should only be
used when the then/else branch functions of If, or the body function of While,
do not contain stateful ops.

They are lowered to the same XLA ops.

One use case is in the S4TF compiler: https://github.com/apple/swift/pull/18509

PiperOrigin-RevId: 207977126
This commit is contained in:
Mingsheng Hong 2018-08-08 18:00:05 -07:00 committed by TensorFlower Gardener
parent f2ff7b1607
commit 8453e23a8b
11 changed files with 198 additions and 2 deletions

View File

@ -1619,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
TF_DeleteFunction(func1);
}
// This test only works when the TF build includes XLA compiler. One way to set
// this up is via bazel build option "--define with_xla_support=true".
//
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
// something like TENSORFLOW_CAPI_USE_XLA.
#ifdef TENSORFLOW_EAGER_USE_XLA
// Builds a graph containing a "StatelessIf" node whose then/else branches are
// the same function ("BranchFunc", created via DefineFunction), places the
// node on an XLA device, runs it through a CSession, and verifies the output.
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
TF_Function* func;
const std::string funcName = "BranchFunc";
// NOTE(review): judging by the EXPECT_EQ(-17, ...) below with input 17, the
// helper presumably defines a function that negates its int32 input -- confirm
// against DefineFunction's implementation.
DefineFunction(funcName.c_str(), &func);
// The branch function must be registered with the graph before any node can
// reference it by name.
TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* feed = Placeholder(host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Constant true predicate: the then_branch is always taken.
TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_OperationDescription* desc =
TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
// First input is the scalar predicate; the remaining inputs are the list
// forwarded to the selected branch function.
TF_AddInput(desc, {true_cond, 0});
TF_Output inputs[] = {{feed, 0}};
TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_SetAttrType(desc, "Tcond", TF_BOOL);
TF_DataType inputType = TF_INT32;
TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
// Both branches point at the same function, so the result is the same
// regardless of which branch the predicate selects.
TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
// Force placement on the XLA CPU device so the node is compiled via XLA.
TF_SetDevice(desc, "/device:XLA_CPU:0");
auto op = TF_FinishOperation(desc, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
ASSERT_NE(op, nullptr);
// Create a session for this graph.
CSession csession(host_graph_, s_, /*use_XLA*/ true);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Run the graph.
csession.SetInputs({{feed, Int32Tensor(17)}});
csession.SetOutputs({op});
csession.Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// The branch function maps the scalar int32 17 to a scalar int32 result.
TF_Tensor* out = csession.output_tensor(0);
ASSERT_TRUE(out != nullptr);
EXPECT_EQ(TF_INT32, TF_TensorType(out));
EXPECT_EQ(0, TF_NumDims(out)); // scalar
ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
int32* output_contents = static_cast<int32*>(TF_TensorData(out));
EXPECT_EQ(-17, *output_contents);
// Clean up
csession.CloseAndDelete(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_DeleteFunction(func);
}
#endif // TENSORFLOW_EAGER_USE_XLA
} // namespace
} // namespace tensorflow

View File

@ -26,6 +26,10 @@ limitations under the License.
using tensorflow::GraphDef;
using tensorflow::NodeDef;
// Deallocator handed to TF_NewTensor for TF_BOOL buffers: releases the
// new[]-allocated bool array backing the tensor. `arg` is unused.
static void BoolDeallocator(void* data, size_t, void* arg) {
  auto* buf = static_cast<bool*>(data);
  delete[] buf;
}
// Matching deallocator for int32 tensor buffers allocated with new[].
// `arg` is unused.
static void Int32Deallocator(void* data, size_t, void* arg) {
  auto* values = static_cast<int32_t*>(data);
  delete[] values;
}
@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a scalar (rank-0) TF_BOOL tensor holding `v`. The single-element
// heap buffer is owned by the tensor and freed via BoolDeallocator.
TF_Tensor* BoolTensor(bool v) {
  bool* data = new bool[1];
  data[0] = v;
  return TF_NewTensor(TF_BOOL, /*dims=*/nullptr, /*num_dims=*/0, data,
                      sizeof(bool), &BoolDeallocator, /*deallocator_arg=*/nullptr);
}
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
return op;
}
// Adds a Const node holding a scalar TF_BOOL value to `graph`.
TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
                          const char* name) {
  unique_tensor_ptr t{BoolTensor(v), TF_DeleteTensor};
  return Const(t.get(), graph, s, name);
}
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);

View File

@ -31,6 +31,8 @@ using ::tensorflow::string;
typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
unique_tensor_ptr;
// Create a scalar tensor of type TF_BOOL with value `v`.
// Parameter type is `bool` to match the definition in the .cc file; the
// previous `int32_t` declaration was a distinct overload with no definition,
// which fails at link time.
TF_Tensor* BoolTensor(bool v);
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name = "const");
// Adds a Const node with a scalar TF_BOOL value to `graph`.
TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");
// Adds a Const node with a scalar TF_INT32 value to `graph`.
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");

View File

@ -247,6 +247,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp);
// StatelessIf is lowered with the same XlaIfOp kernel as "If".
REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp);
REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp);
} // namespace tensorflow

View File

@ -301,6 +301,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
}
REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp);
// StatelessWhile is lowered with the same XlaWhileOp kernel as "While".
REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);
} // namespace tensorflow

View File

@ -0,0 +1,43 @@
# ApiDef for StatelessIf. Fixed: the "cond" input argument was declared twice
# (a short one-line entry plus the full entry below); duplicate in_arg names
# are invalid, so the redundant entry is removed.
op {
  graph_op_name: "StatelessIf"
  in_arg {
    name: "cond"
    description: <<END
A Tensor. If the tensor is a scalar of non-boolean type, the
scalar is converted to a boolean according to the
following rule: if the scalar is a numerical value, non-zero means
`True` and zero means False; if the scalar is a string, non-empty
means `True` and empty means `False`. If the tensor is not a scalar,
being empty means False and being non-empty means True.

This should only be used when the if then/else body functions do not
have stateful ops.
END
  }
  in_arg {
    name: "input"
    description: "A list of input tensors."
  }
  out_arg {
    name: "output"
    description: "A list of return values."
  }
  attr { name: "Tin" description: "A list of input types." }
  attr { name: "Tout" description: "A list of output types." }
  attr {
    name: "then_branch"
    description: <<END
A function that takes 'inputs' and returns a list of tensors, whose
types are the same as what else_branch returns.
END
  }
  attr {
    name: "else_branch"
    description: <<END
A function that takes 'inputs' and returns a list of tensors, whose
types are the same as what then_branch returns.
END
  }
  summary: "output = cond ? then_branch(input) : else_branch(input)"
}

View File

@ -0,0 +1,36 @@
# ApiDef for StatelessWhile. Fixed: grammar in the "cond" attr description
# ("A function takes" -> "A function that takes").
op {
  graph_op_name: "StatelessWhile"
  in_arg {
    name: "input"
    description: "A list of input tensors whose types are T."
  }
  out_arg {
    name: "output"
    description: "A list of output tensors whose types are T."
  }
  attr { name: "T" description: "dtype in use." }
  attr {
    name: "cond"
    description: <<END
A function that takes 'input' and returns a tensor. If the tensor is
a scalar of non-boolean, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means True and zero means False; if the scalar is
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True and False
otherwise.

This should only be used when the while condition and body functions
do not have stateful ops.
END
  }
  attr {
    name: "body"
    description: <<END
A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified
by T.
END
  }
  summary: "output = input; While (Cond(output)) { output = Body(output) }"
}

View File

@ -0,0 +1 @@
# Mark StatelessIf as hidden: excluded from the generated public API surface.
op { graph_op_name: "StatelessIf" visibility: HIDDEN }

View File

@ -0,0 +1 @@
# Mark StatelessWhile as hidden: excluded from the generated public API surface.
op { graph_op_name: "StatelessWhile" visibility: HIDDEN }

View File

@ -218,6 +218,10 @@ REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);
// StatelessIf reuses the same IfOp kernel as "If". On GPU the scalar
// predicate stays in host memory (HostMemory("cond")).
REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(
Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp);
class WhileOp : public AsyncOpKernel {
public:
explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
@ -379,6 +383,9 @@ REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp);
// StatelessWhile reuses the same WhileOp kernel registered for "While" above.
REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp);
Status GetScalar(OpKernelContext* ctx, int index, int32* value,
const char* label) {
Tensor t = ctx->input(index);

View File

@ -90,6 +90,17 @@ else_branch: A function that takes 'inputs' and returns a list of
tensors. whose types are the same as what then_branch returns.
)doc");
// Stateless counterpart of "If" (registered below) with the same inputs,
// outputs, and attrs. Should only be used when the then/else branch
// functions contain no stateful ops. Output shapes are unknown, matching
// the behavior of "If".
REGISTER_OP("StatelessIf")
.Input("cond: Tcond")
.Input("input: Tin")
.Output("output: Tout")
.Attr("Tcond: type")
.Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >= 0")
.Attr("then_branch: func")
.Attr("else_branch: func")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("If")
.Input("cond: Tcond")
.Input("input: Tin")
@ -133,8 +144,6 @@ body: A function that takes a list of tensors and returns another
by T.
)doc");
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("While")
.Input("input: T")
.Output("output: T")
@ -149,6 +158,19 @@ REGISTER_OP("While")
return Status::OK();
});
// Stateless counterpart of "While". Should only be used when the cond and
// body functions contain no stateful ops.
REGISTER_OP("StatelessWhile")
.Input("input: T")
.Output("output: T")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
// Each output carries the shape of the corresponding loop-carried input.
for (int i = 0; i < c->num_outputs(); ++i) {
c->set_output(i, c->input(i));
}
return Status::OK();
});
REGISTER_OP("For")
.Input("start: int32")
.Input("limit: int32")