Move down_cast to platform/casts.h

PiperOrigin-RevId: 289736444
Change-Id: I93321c9130243b15d789bd4ec63588d54adc011e
This commit is contained in:
Gaurav Jain 2020-01-14 14:53:18 -08:00 committed by TensorFlower Gardener
parent 99eb226655
commit c8dc5d6d53
10 changed files with 150 additions and 30 deletions

View File

@ -1974,6 +1974,7 @@ cc_library(
"//tensorflow/core/platform:abi",
"//tensorflow/core/platform:base64",
"//tensorflow/core/platform:blocking_counter",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:coding",
"//tensorflow/core/platform:context",
"//tensorflow/core/platform:cord",

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"

View File

@ -124,20 +124,6 @@ class IteratorHandleOp : public OpKernel {
// inconsistent capacities.
Status VerifyResource(IteratorResource* resource);
template <typename To, typename From> // use like this: down_cast<T*>(foo);
static inline To down_cast(From* f) { // so we only accept pointers
static_assert(
(std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
"target type not derived from source type");
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
// Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
return static_cast<To>(f);
}
FunctionLibraryRuntime* CreatePrivateFLR(
OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
@ -41,21 +42,6 @@ Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
}
template <typename To, typename From> // use like this: down_cast<T*>(foo);
inline To down_cast(From* f) { // so we only accept pointers
static_assert(
(std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
"target type not derived from source type");
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
// Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
return static_cast<To>(f);
}
// If "t" is a scalar of a supported type, returns t != 0 in "*v".
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
if (t.size() != 1) {

View File

@ -73,6 +73,7 @@ limitations under the License.
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@ -104,7 +105,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
} else if (ctx->op_device_context() != nullptr) {
// TODO(apassos): remove the down_cast by just returning Device* from
// OpKernelContext
Device* device = static_cast<Device*>(ctx->device());
Device* device = down_cast<Device*>(ctx->device());
ctx->op_device_context()->CopyTensorInSameDevice(
t, device, output, [&n, &status](const Status& s) {
status = s;

View File

@ -158,6 +158,14 @@ cc_library(
],
)
cc_library(
name = "casts",
hdrs = ["casts.h"],
deps = [
":platform",
] + tf_platform_deps("casts"),
)
cc_library(
name = "cuda",
hdrs = ["cuda.h"],
@ -1060,6 +1068,7 @@ filegroup(
name = "lib_hdrs",
srcs = [
"abi.h",
"casts.h",
"context.h",
"cpu_feature_guard.h",
"cpu_info.h",
@ -1254,6 +1263,7 @@ filegroup(
"//tensorflow/core/platform:base64.h",
"//tensorflow/core/platform:blocking_counter.h",
"//tensorflow/core/platform:byte_order.h",
"//tensorflow/core/platform:casts.h",
"//tensorflow/core/platform:coding.cc",
"//tensorflow/core/platform:coding.h",
"//tensorflow/core/platform:context.h",
@ -1360,6 +1370,7 @@ filegroup(
name = "legacy_srcs_no_runtime",
srcs = [
":legacy_srcs_common",
"//tensorflow/core/platform/default:casts.h",
"//tensorflow/core/platform/default:context.h",
"//tensorflow/core/platform/default:cord.h",
"//tensorflow/core/platform/default:dynamic_annotations.h",

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PLATFORM_CASTS_H_
#define TENSORFLOW_CORE_PLATFORM_CASTS_H_
#include "tensorflow/core/platform/platform.h"
#if defined(PLATFORM_GOOGLE)
#include "tensorflow/core/platform/google/casts.h"
#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \
defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS)
#include "tensorflow/core/platform/default/casts.h"
#else
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif
#endif // TENSORFLOW_CORE_PLATFORM_CASTS_H_

View File

@ -9,6 +9,16 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "casts",
hdrs = ["casts.h"],
tags = [
"manual",
"no_oss",
"nobuilder",
],
)
cc_library(
name = "context",
hdrs = ["//tensorflow/core/platform:context.h"],

View File

@ -541,6 +541,7 @@ def tf_proto_library(
def tf_additional_lib_hdrs():
return [
"//tensorflow/core/platform/default:casts.h",
"//tensorflow/core/platform/default:context.h",
"//tensorflow/core/platform/default:cord.h",
"//tensorflow/core/platform/default:dynamic_annotations.h",

View File

@ -0,0 +1,92 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_
#include <assert.h> // for use with down_cast<>
#include <type_traits>
namespace tensorflow {
// An "upcast", i.e. a conversion from a pointer to an object to a pointer to a
// base subobject, always succeeds if the base is unambiguous and accessible,
// and so it's fine to use implicit_cast.
//
// A "downcast", i.e. a conversion from a pointer to an object to a pointer
// to a more-derived object that may contain the original object as a base
// subobject, cannot safely be done using static_cast, because you do not
// generally know whether the source object is really the base subobject of
// a containing, more-derived object of the target type. Thus, when you
// downcast in a polymorphic type hierarchy, you should use the following
// function template.
//
// In debug mode, we use dynamic_cast to double-check whether the downcast is
// legal (we die if it's not). In normal mode, we do the efficient static_cast
// instead. Thus, it's important to test in debug mode to make sure the cast is
// legal!
//
// This is the only place in the codebase we should use dynamic_cast.
// In particular, you should NOT use dynamic_cast for RTTI, e.g. for
// code like this:
// if (auto* p = dynamic_cast<Subclass1*>(foo)) HandleASubclass1Object(p);
// if (auto* p = dynamic_cast<Subclass2*>(foo)) HandleASubclass2Object(p);
// You should design the code some other way not to need this.
template <typename To, typename From> // use like this: down_cast<T*>(foo);
inline To down_cast(From* f) { // so we only accept pointers
static_assert(
(std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
"target type not derived from source type");
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
// Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
return static_cast<To>(f);
}
// Overload of down_cast for references. Use like this: down_cast<T&>(foo).
// The code is slightly convoluted because we're still using the pointer
// form of dynamic cast. (The reference form throws an exception if it
// fails.)
//
// There's no need for a special const overload either for the pointer
// or the reference form. If you call down_cast with a const T&, the
// compiler will just bind From to const T.
template <typename To, typename From>
inline To down_cast(From& f) {
static_assert(std::is_lvalue_reference<To>::value,
"target type not a reference");
static_assert(
(std::is_base_of<From, typename std::remove_reference<To>::type>::value),
"target type not derived from source type");
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
// RTTI: debug mode only
assert(dynamic_cast<typename std::remove_reference<To>::type*>(&f) !=
nullptr);
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
return static_cast<To>(f);
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_