Move bitcast registration into C

This change makes the bitcast op use the new C op API. The purpose of
this change is to demonstrate that op registration and shape inference
can be performed using an ABI-stable API. This is a prerequisite to
being able to load ops from external libraries in the modular TF design.

PiperOrigin-RevId: 255720454
This commit is contained in:
James Ring 2019-06-28 22:17:46 -07:00 committed by TensorFlower Gardener
parent 5fbe995c00
commit 05ad600619
11 changed files with 269 additions and 80 deletions

View File

@ -46,7 +46,7 @@ filegroup(
"*test*",
],
) + [
"//tensorflow/cc:srcs",
"//tensorflow/cc:srcs_no_runtime",
"//tensorflow/core/distributed_runtime:server_lib.h",
],
visibility = ["//visibility:public"],
@ -202,6 +202,7 @@ cc_library(
"//tensorflow/core:framework",
],
}),
alwayslink = 1,
)
cc_library(
@ -389,6 +390,7 @@ tf_cuda_library(
"//tensorflow/core:framework",
],
}),
alwayslink = 1,
)
# -----------------------------------------------------------------------------

View File

@ -2,6 +2,7 @@ load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_kernel_library",
"tf_gen_op_libs",
)
package(
@ -23,6 +24,17 @@ tf_kernel_library(
],
)
tf_gen_op_libs(
op_lib_names = ["bitcast"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "bitcast_op_test",
srcs = ["bitcast_op_test.cc"],
@ -44,9 +56,14 @@ tf_cc_test(
#
# LINT.IfChange
filegroup(
name = "android_all_ops",
name = "android_all_op_kernels",
srcs = [
"bitcast_op.cc",
],
)
# LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt)
filegroup(
name = "android_all_ops",
srcs = ["ops/bitcast.cc"],
)

View File

@ -136,7 +136,7 @@ static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
TF_DeleteTensor(tensor);
}
void RegisterBitcastOp() {
void RegisterBitcastOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU,
@ -163,9 +163,9 @@ void RegisterBitcastOp() {
// A dummy static variable initialized by a lambda whose side-effect is to
// register the bitcast kernel.
static bool BitcastOpIsRegistered = []() {
static bool IsBitcastOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("BitcastOp")) {
RegisterBitcastOp();
RegisterBitcastOpKernel();
}
return true;
}();

View File

@ -15,8 +15,11 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -97,5 +100,64 @@ TEST(BitcastOpTest, TestImpossibleCast) {
TestBitcastOp(&int8_input, DT_UINT32, TensorShape(), error::INVALID_ARGUMENT);
}
PartialTensorShape S(std::initializer_list<int64> dims) {
return PartialTensorShape(dims);
}
TEST(BitcastOpTest, TestShapeInference_LargerShape) {
const OpRegistrationData* reg;
TF_CHECK_OK(OpRegistry::Global()->LookUp("Bitcast", &reg));
OpDef op_def = reg->op_def;
NodeDef def;
TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
.Attr("type", DT_INT8)
.Attr("T", DT_INT64)
.Input(FakeInput(DT_INT64))
.Finalize(&def));
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
std::vector<shape_inference::ShapeHandle> input_shapes;
TF_CHECK_OK(c.input("input", &input_shapes));
ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
TF_CHECK_OK(reg->shape_inference_fn(&c));
ASSERT_EQ("[3,4,8]", c.DebugString(c.output(0)));
}
TEST(BitcastOpTest, TestShapeInference_SmallerShape) {
const OpRegistrationData* reg;
TF_CHECK_OK(OpRegistry::Global()->LookUp("Bitcast", &reg));
OpDef op_def = reg->op_def;
NodeDef def;
TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
.Attr("type", DT_INT64)
.Attr("T", DT_INT8)
.Input(FakeInput(DT_INT8))
.Finalize(&def));
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4, 8})}, {}, {},
{});
std::vector<shape_inference::ShapeHandle> input_shapes;
TF_CHECK_OK(c.input("input", &input_shapes));
ASSERT_EQ("[3,4,8]", c.DebugString(input_shapes[0]));
TF_CHECK_OK(reg->shape_inference_fn(&c));
ASSERT_EQ("[3,4]", c.DebugString(c.output(0)));
}
TEST(BitcastOpTest, TestShapeInference_SameShape) {
const OpRegistrationData* reg;
TF_CHECK_OK(OpRegistry::Global()->LookUp("Bitcast", &reg));
OpDef op_def = reg->op_def;
NodeDef def;
TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
.Attr("type", DT_INT32)
.Attr("T", DT_FLOAT)
.Input(FakeInput(DT_FLOAT))
.Finalize(&def));
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
std::vector<shape_inference::ShapeHandle> input_shapes;
TF_CHECK_OK(c.input("input", &input_shapes));
ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
TF_CHECK_OK(reg->shape_inference_fn(&c));
ASSERT_EQ("[3,4]", c.DebugString(c.output(0)));
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,135 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include <string>
#include "tensorflow/c/ops.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
static void ComputeNewShape(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* shape, size_t input_type_size,
size_t output_type_size, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
if (input_type_size < output_type_size) {
TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status);
if (TF_GetCode(status) == TF_OK) {
TF_DimensionHandle* last_dim = TF_NewDimensionHandle();
size_t divisor_val = output_type_size / input_type_size;
TF_ShapeInferenceContextDim(ctx, shape, -1, last_dim);
if (!TF_DimensionHandleValueKnown(last_dim) ||
TF_DimensionHandleValue(last_dim) == divisor_val) {
TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status);
} else {
std::ostringstream err;
err << "Cannot bitcast due to shape. "
<< TF_DimensionHandleValue(last_dim) << " does not match "
<< divisor_val;
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
}
TF_DeleteDimensionHandle(last_dim);
}
} else if (input_type_size > output_type_size) {
// Input type size is larger than output type size.
size_t divisor_val = input_type_size / output_type_size;
TF_ShapeHandle* extension =
TF_ShapeInferenceContextVectorFromSize(ctx, divisor_val);
TF_ShapeInferenceContextConcatenateShapes(ctx, shape, extension, shape,
status);
TF_DeleteShapeHandle(extension);
}
}
static void bitcast_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_ShapeHandle* result = TF_NewShapeHandle();
TF_ShapeInferenceContextGetInput(ctx, 0, result, status);
if (TF_GetCode(status) == TF_OK &&
!TF_ShapeInferenceContextRankKnown(ctx, result)) {
TF_ShapeInferenceContextSetUnknownShape(ctx, status);
TF_DeleteShapeHandle(result);
return;
}
// Find the size of the input and output data types.
TF_DataType input_type;
TF_DataType output_type;
if (TF_GetCode(status) == TF_OK) {
TF_ShapeInferenceContext_GetAttrType(ctx, "T", &input_type, status);
}
if (TF_GetCode(status) == TF_OK) {
TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status);
}
size_t input_type_size;
size_t output_type_size;
if (TF_GetCode(status) == TF_OK) {
input_type_size = TF_DataTypeSize(input_type);
output_type_size = TF_DataTypeSize(output_type);
if (input_type_size == 0 || output_type_size == 0) {
std::ostringstream err;
err << "Cannot bitcast type " << input_type << " to " << output_type
<< " because one of the type sizes is zero";
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
}
}
if (TF_GetCode(status) == TF_OK) {
ComputeNewShape(ctx, result, input_type_size, output_type_size, status);
}
if (TF_GetCode(status) == TF_OK) {
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
}
TF_DeleteShapeHandle(result);
}
void RegisterBitcastOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder = TF_NewOpDefinitionBuilder("Bitcast");
TF_OpDefinitionBuilderAddInput(op_builder, "input: T");
TF_OpDefinitionBuilderAddOutput(op_builder, "output: type");
TF_OpDefinitionBuilderAddAttr(
op_builder,
"T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
"uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
"qint16, quint16, qint32}");
TF_OpDefinitionBuilderAddAttr(
op_builder,
"type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
"uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
"qint16, quint16, qint32}");
TF_OpDefinitionBuilderSetShapeInferenceFunction(op_builder,
&bitcast_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "Bitcast op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
static bool IsBitcastOpRegistered = []() {
if (SHOULD_REGISTER_OP("Bitcast")) {
RegisterBitcastOp();
}
return true;
}();

View File

@ -17,6 +17,17 @@ package(
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "srcs_no_runtime",
srcs = [
"framework/gradients.h",
"framework/ops.h",
"framework/scope.h",
"framework/scope_internal.h",
"//tensorflow/cc/saved_model:loader.h",
],
)
filegroup(
name = "srcs",
srcs = [
@ -485,6 +496,7 @@ tf_gen_op_wrappers_cc(
tf_gen_op_wrappers_cc(
name = "array_ops",
api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
extra_gen_deps = ["//tensorflow/c/kernels:bitcast_op_lib"],
op_lib_names = [
"array_ops",
],

View File

@ -1448,6 +1448,7 @@ cc_library(
":training_ops_op_lib",
":user_ops_op_lib",
":word2vec_ops",
"//tensorflow/c/kernels:bitcast_op_lib",
] + if_mkl([
":mkl_array_ops_op_lib",
":mkl_nn_ops_op_lib",
@ -1467,6 +1468,7 @@ cc_library(
":array_ops_op_lib",
":framework",
":lib",
"//tensorflow/c/kernels:bitcast_op_lib",
],
alwayslink = 1,
)
@ -2083,7 +2085,7 @@ cc_library(
filegroup(
name = "android_op_registrations_and_gradients",
srcs = glob(
srcs = ["//tensorflow/c/kernels:android_all_ops"] + glob(
[
"ops/**/*.cc",
"ops/**/*.h",
@ -3140,6 +3142,20 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [
tf_cuda_library(
name = "core_cpu_base",
hdrs = CORE_CPU_BASE_HDRS + ["public/session.h"],
copts = tf_copts(),
deps = [":core_cpu_base_no_ops"] + if_static([
":function_ops_op_lib",
":functional_grad",
":functional_ops_op_lib",
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:required",
]),
alwayslink = 1,
)
tf_cuda_library(
name = "core_cpu_base_no_ops",
srcs = [
"common_runtime/eval_const_tensor.cc",
"common_runtime/scoped_allocator.cc",
@ -3148,11 +3164,10 @@ tf_cuda_library(
"common_runtime/graph_optimizer.h",
"graph/graph_constructor.cc", # Depends on common_runtime.
"graph/graph_def_builder_util.cc", # Depends on common_runtime.
"public/session.h",
"public/session_options.h",
"public/version.h",
] + CORE_CPU_BASE_HDRS,
hdrs = CORE_CPU_BASE_HDRS,
hdrs = CORE_CPU_BASE_HDRS + ["public/session.h"],
copts = tf_copts(),
deps = [
":graph",
@ -3165,14 +3180,8 @@ tf_cuda_library(
"@com_google_absl//absl/container:flat_hash_set",
"//third_party/eigen3",
] + if_static([
":function_ops_op_lib",
":functional_grad",
":functional_ops_op_lib",
"@com_google_absl//absl/algorithm:container",
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:required",
]),
alwayslink = 1,
)
CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
@ -3338,6 +3347,16 @@ tf_cuda_library(
] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(),
)
tf_cuda_library(
name = "core_cpu_lib_no_ops",
hdrs = CORE_CPU_LIB_HEADERS,
deps = [
":core_cpu_base_no_ops",
":proto_text",
"//tensorflow/core/grappler:grappler_item",
] + tf_protos_all() + tf_protos_grappler(),
)
tf_cuda_library(
name = "core_cpu_internal",
srcs = [
@ -4925,6 +4944,7 @@ tf_cc_test(
":test",
":test_main",
":testlib",
"//tensorflow/c/kernels:bitcast_op_lib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/core/kernels:cwise_op",

View File

@ -178,7 +178,7 @@ cc_library(
hdrs = ["functions.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:core_cpu_base_no_ops",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",

View File

@ -6049,7 +6049,7 @@ filegroup(
"unpack_op.cc",
"variable_ops.cc",
"variable_ops.h",
"//tensorflow/c/kernels:android_all_ops",
"//tensorflow/c/kernels:android_all_op_kernels",
],
)
@ -6392,7 +6392,7 @@ ANDROID_TEXTUAL_HDRS = [
# registration.
filegroup(
name = "android_all_ops",
srcs = ["//tensorflow/c/kernels:android_all_ops"] + glob(
srcs = ["//tensorflow/c/kernels:android_all_op_kernels"] + glob(
[
"*.cc",
"*.h",

View File

@ -2908,69 +2908,6 @@ REGISTER_OP("ExtractVolumePatches")
// --------------------------------------------------------------------------
REGISTER_OP("Bitcast")
.Input("input: T")
.Output("output: type")
// All supported dtypes are listed here to include qint16, quint16, uint32,
// and uint64.
.Attr(
"T: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
"uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
"qint16, quint16, qint32}")
.Attr(
"type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
"uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, "
"qint16, quint16, qint32}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
if (!c->RankKnown(input)) {
// Input shape unknown.
return shape_inference::UnknownShape(c);
}
// Find the size of the input and output data types.
DataType input_type;
DataType output_type;
TF_RETURN_IF_ERROR(c->GetAttr("T", &input_type));
TF_RETURN_IF_ERROR(c->GetAttr("type", &output_type));
const int input_type_size = DataTypeSize(input_type);
const int output_type_size = DataTypeSize(output_type);
if (input_type_size == 0 || output_type_size == 0) {
return errors::InvalidArgument("Cannot bitcast types ",
DataTypeString(input_type), " to ",
DataTypeString(output_type),
" because "
"one of the type sizes is zero.");
}
ShapeHandle new_shape;
if (input_type_size == output_type_size) {
// No change in size.
new_shape = input;
} else if (input_type_size < output_type_size) {
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 1, &new_shape));
int64 divisor_val = output_type_size / input_type_size;
DimensionHandle last_dim = c->Dim(new_shape, -1);
if (!c->ValueKnown(last_dim) || c->Value(last_dim) == divisor_val) {
TF_RETURN_IF_ERROR(c->Subshape(new_shape, 0, -1, &new_shape));
} else {
return errors::InvalidArgument("Cannot bitcast due to shape. ",
c->Value(last_dim), " does not match ",
divisor_val);
}
} else {
// Input type size is larger than output type size.
int64 divisor_val = input_type_size / output_type_size;
ShapeHandle extension = c->Vector(divisor_val);
TF_RETURN_IF_ERROR(c->Concatenate(input, extension, &new_shape));
}
c->set_output(0, new_shape);
return Status::OK();
});
REGISTER_OP("OneHot")
.Input("indices: TI")
.Input("depth: int32")

View File

@ -1883,6 +1883,10 @@ tf_gen_op_wrapper_private_py(
"//tensorflow/contrib/quantization:__pkg__",
"//tensorflow/python/kernel_tests:__pkg__",
],
deps = [
"//tensorflow/c/kernels:bitcast_op_lib",
"//tensorflow/core:array_ops_op_lib",
],
)
tf_gen_op_wrapper_private_py(