Move down_cast to platform/casts.h
PiperOrigin-RevId: 289736444 Change-Id: I93321c9130243b15d789bd4ec63588d54adc011e
This commit is contained in:
parent
99eb226655
commit
c8dc5d6d53
@ -1974,6 +1974,7 @@ cc_library(
|
|||||||
"//tensorflow/core/platform:abi",
|
"//tensorflow/core/platform:abi",
|
||||||
"//tensorflow/core/platform:base64",
|
"//tensorflow/core/platform:base64",
|
||||||
"//tensorflow/core/platform:blocking_counter",
|
"//tensorflow/core/platform:blocking_counter",
|
||||||
|
"//tensorflow/core/platform:casts",
|
||||||
"//tensorflow/core/platform:coding",
|
"//tensorflow/core/platform:coding",
|
||||||
"//tensorflow/core/platform:context",
|
"//tensorflow/core/platform:context",
|
||||||
"//tensorflow/core/platform:cord",
|
"//tensorflow/core/platform:cord",
|
||||||
|
@ -43,6 +43,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/refcount.h"
|
#include "tensorflow/core/platform/refcount.h"
|
||||||
|
@ -124,20 +124,6 @@ class IteratorHandleOp : public OpKernel {
|
|||||||
// inconsistent capacities.
|
// inconsistent capacities.
|
||||||
Status VerifyResource(IteratorResource* resource);
|
Status VerifyResource(IteratorResource* resource);
|
||||||
|
|
||||||
template <typename To, typename From> // use like this: down_cast<T*>(foo);
|
|
||||||
static inline To down_cast(From* f) { // so we only accept pointers
|
|
||||||
static_assert(
|
|
||||||
(std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
|
|
||||||
"target type not derived from source type");
|
|
||||||
|
|
||||||
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
|
|
||||||
#if !defined(__GNUC__) || defined(__GXX_RTTI)
|
|
||||||
// Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
|
|
||||||
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
|
|
||||||
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
|
|
||||||
return static_cast<To>(f);
|
|
||||||
}
|
|
||||||
|
|
||||||
FunctionLibraryRuntime* CreatePrivateFLR(
|
FunctionLibraryRuntime* CreatePrivateFLR(
|
||||||
OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
|
OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
|
||||||
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
|
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -41,21 +42,6 @@ Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
|
|||||||
return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
|
return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename To, typename From> // use like this: down_cast<T*>(foo);
|
|
||||||
inline To down_cast(From* f) { // so we only accept pointers
|
|
||||||
static_assert(
|
|
||||||
(std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
|
|
||||||
"target type not derived from source type");
|
|
||||||
|
|
||||||
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
|
|
||||||
#if !defined(__GNUC__) || defined(__GXX_RTTI)
|
|
||||||
// Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
|
|
||||||
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
|
|
||||||
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
|
|
||||||
|
|
||||||
return static_cast<To>(f);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If "t" is a scalar of a supported type, returns t != 0 in "*v".
|
// If "t" is a scalar of a supported type, returns t != 0 in "*v".
|
||||||
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
|
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
|
||||||
if (t.size() != 1) {
|
if (t.size() != 1) {
|
||||||
|
@ -73,6 +73,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/variable_ops.h"
|
#include "tensorflow/core/kernels/variable_ops.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/refcount.h"
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/mem.h"
|
#include "tensorflow/core/platform/mem.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -104,7 +105,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
|
|||||||
} else if (ctx->op_device_context() != nullptr) {
|
} else if (ctx->op_device_context() != nullptr) {
|
||||||
// TODO(apassos): remove the down_cast by just returning Device* from
|
// TODO(apassos): remove the down_cast by just returning Device* from
|
||||||
// OpKernelContext
|
// OpKernelContext
|
||||||
Device* device = static_cast<Device*>(ctx->device());
|
Device* device = down_cast<Device*>(ctx->device());
|
||||||
ctx->op_device_context()->CopyTensorInSameDevice(
|
ctx->op_device_context()->CopyTensorInSameDevice(
|
||||||
t, device, output, [&n, &status](const Status& s) {
|
t, device, output, [&n, &status](const Status& s) {
|
||||||
status = s;
|
status = s;
|
||||||
|
@ -158,6 +158,14 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "casts",
|
||||||
|
hdrs = ["casts.h"],
|
||||||
|
deps = [
|
||||||
|
":platform",
|
||||||
|
] + tf_platform_deps("casts"),
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cuda",
|
name = "cuda",
|
||||||
hdrs = ["cuda.h"],
|
hdrs = ["cuda.h"],
|
||||||
@ -1060,6 +1068,7 @@ filegroup(
|
|||||||
name = "lib_hdrs",
|
name = "lib_hdrs",
|
||||||
srcs = [
|
srcs = [
|
||||||
"abi.h",
|
"abi.h",
|
||||||
|
"casts.h",
|
||||||
"context.h",
|
"context.h",
|
||||||
"cpu_feature_guard.h",
|
"cpu_feature_guard.h",
|
||||||
"cpu_info.h",
|
"cpu_info.h",
|
||||||
@ -1254,6 +1263,7 @@ filegroup(
|
|||||||
"//tensorflow/core/platform:base64.h",
|
"//tensorflow/core/platform:base64.h",
|
||||||
"//tensorflow/core/platform:blocking_counter.h",
|
"//tensorflow/core/platform:blocking_counter.h",
|
||||||
"//tensorflow/core/platform:byte_order.h",
|
"//tensorflow/core/platform:byte_order.h",
|
||||||
|
"//tensorflow/core/platform:casts.h",
|
||||||
"//tensorflow/core/platform:coding.cc",
|
"//tensorflow/core/platform:coding.cc",
|
||||||
"//tensorflow/core/platform:coding.h",
|
"//tensorflow/core/platform:coding.h",
|
||||||
"//tensorflow/core/platform:context.h",
|
"//tensorflow/core/platform:context.h",
|
||||||
@ -1360,6 +1370,7 @@ filegroup(
|
|||||||
name = "legacy_srcs_no_runtime",
|
name = "legacy_srcs_no_runtime",
|
||||||
srcs = [
|
srcs = [
|
||||||
":legacy_srcs_common",
|
":legacy_srcs_common",
|
||||||
|
"//tensorflow/core/platform/default:casts.h",
|
||||||
"//tensorflow/core/platform/default:context.h",
|
"//tensorflow/core/platform/default:context.h",
|
||||||
"//tensorflow/core/platform/default:cord.h",
|
"//tensorflow/core/platform/default:cord.h",
|
||||||
"//tensorflow/core/platform/default:dynamic_annotations.h",
|
"//tensorflow/core/platform/default:dynamic_annotations.h",
|
||||||
|
31
tensorflow/core/platform/casts.h
Normal file
31
tensorflow/core/platform/casts.h
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_PLATFORM_CASTS_H_
|
||||||
|
#define TENSORFLOW_CORE_PLATFORM_CASTS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/platform.h"
|
||||||
|
|
||||||
|
#if defined(PLATFORM_GOOGLE)
|
||||||
|
#include "tensorflow/core/platform/google/casts.h"
|
||||||
|
#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
|
||||||
|
defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \
|
||||||
|
defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS)
|
||||||
|
#include "tensorflow/core/platform/default/casts.h"
|
||||||
|
#else
|
||||||
|
#error Define the appropriate PLATFORM_<foo> macro for this platform
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_PLATFORM_CASTS_H_
|
@ -9,6 +9,16 @@ package(
|
|||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "casts",
|
||||||
|
hdrs = ["casts.h"],
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"no_oss",
|
||||||
|
"nobuilder",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "context",
|
name = "context",
|
||||||
hdrs = ["//tensorflow/core/platform:context.h"],
|
hdrs = ["//tensorflow/core/platform:context.h"],
|
||||||
|
@ -541,6 +541,7 @@ def tf_proto_library(
|
|||||||
|
|
||||||
def tf_additional_lib_hdrs():
|
def tf_additional_lib_hdrs():
|
||||||
return [
|
return [
|
||||||
|
"//tensorflow/core/platform/default:casts.h",
|
||||||
"//tensorflow/core/platform/default:context.h",
|
"//tensorflow/core/platform/default:context.h",
|
||||||
"//tensorflow/core/platform/default:cord.h",
|
"//tensorflow/core/platform/default:cord.h",
|
||||||
"//tensorflow/core/platform/default:dynamic_annotations.h",
|
"//tensorflow/core/platform/default:dynamic_annotations.h",
|
||||||
|
92
tensorflow/core/platform/default/casts.h
Normal file
92
tensorflow/core/platform/default/casts.h
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_
|
||||||
|
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_
|
||||||
|
|
||||||
|
#include <assert.h> // for use with down_cast<>
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// An "upcast", i.e. a conversion from a pointer to an object to a pointer to a
|
||||||
|
// base subobject, always succeeds if the base is unambiguous and accessible,
|
||||||
|
// and so it's fine to use implicit_cast.
|
||||||
|
//
|
||||||
|
// A "downcast", i.e. a conversion from a pointer to an object to a pointer
|
||||||
|
// to a more-derived object that may contain the original object as a base
|
||||||
|
// subobject, cannot safely be done using static_cast, because you do not
|
||||||
|
// generally know whether the source object is really the base subobject of
|
||||||
|
// a containing, more-derived object of the target type. Thus, when you
|
||||||
|
// downcast in a polymorphic type hierarchy, you should use the following
|
||||||
|
// function template.
|
||||||
|
//
|
||||||
|
// In debug mode, we use dynamic_cast to double-check whether the downcast is
|
||||||
|
// legal (we die if it's not). In normal mode, we do the efficient static_cast
|
||||||
|
// instead. Thus, it's important to test in debug mode to make sure the cast is
|
||||||
|
// legal!
|
||||||
|
//
|
||||||
|
// This is the only place in the codebase we should use dynamic_cast.
|
||||||
|
// In particular, you should NOT use dynamic_cast for RTTI, e.g. for
|
||||||
|
// code like this:
|
||||||
|
// if (auto* p = dynamic_cast<Subclass1*>(foo)) HandleASubclass1Object(p);
|
||||||
|
// if (auto* p = dynamic_cast<Subclass2*>(foo)) HandleASubclass2Object(p);
|
||||||
|
// You should design the code some other way not to need this.
|
||||||
|
|
||||||
|
template <typename To, typename From> // use like this: down_cast<T*>(foo);
|
||||||
|
inline To down_cast(From* f) { // so we only accept pointers
|
||||||
|
static_assert(
|
||||||
|
(std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
|
||||||
|
"target type not derived from source type");
|
||||||
|
|
||||||
|
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
|
||||||
|
#if !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||||
|
// Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
|
||||||
|
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
|
||||||
|
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||||
|
|
||||||
|
return static_cast<To>(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overload of down_cast for references. Use like this: down_cast<T&>(foo).
|
||||||
|
// The code is slightly convoluted because we're still using the pointer
|
||||||
|
// form of dynamic cast. (The reference form throws an exception if it
|
||||||
|
// fails.)
|
||||||
|
//
|
||||||
|
// There's no need for a special const overload either for the pointer
|
||||||
|
// or the reference form. If you call down_cast with a const T&, the
|
||||||
|
// compiler will just bind From to const T.
|
||||||
|
template <typename To, typename From>
|
||||||
|
inline To down_cast(From& f) {
|
||||||
|
static_assert(std::is_lvalue_reference<To>::value,
|
||||||
|
"target type not a reference");
|
||||||
|
static_assert(
|
||||||
|
(std::is_base_of<From, typename std::remove_reference<To>::type>::value),
|
||||||
|
"target type not derived from source type");
|
||||||
|
|
||||||
|
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
|
||||||
|
#if !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||||
|
// RTTI: debug mode only
|
||||||
|
assert(dynamic_cast<typename std::remove_reference<To>::type*>(&f) !=
|
||||||
|
nullptr);
|
||||||
|
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||||
|
|
||||||
|
return static_cast<To>(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_
|
Loading…
Reference in New Issue
Block a user