Move ImmediateExecutionTensorHandle::DataType() to AbstractTensorHandle.
Minor signature change of AddInputList. Implement AddInputList for GraphContext. This is needed for downstream implementation of gradients. PiperOrigin-RevId: 320513498 Change-Id: I16ee672c6d6f1f9b260b319c954158694c82a1db
This commit is contained in:
parent
78199b484c
commit
fdce31d861
@ -177,7 +177,9 @@ cc_library(
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -73,7 +73,8 @@ class AbstractOperation {
|
||||
virtual Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual Status AddInput(AbstractTensorHandle* input) = 0;
|
||||
virtual Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) = 0;
|
||||
virtual Status AddInputList(
|
||||
absl::Span<AbstractTensorHandle* const> inputs) = 0;
|
||||
virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) = 0;
|
||||
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a Tensor handle in either tracing or immediate
|
||||
@ -27,6 +29,9 @@ class AbstractTensorHandle {
|
||||
virtual ~AbstractTensorHandle() {}
|
||||
|
||||
public:
|
||||
// Returns tensor dtype.
|
||||
virtual tensorflow::DataType DataType() const = 0;
|
||||
|
||||
AbstractTensorHandleKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
|
@ -49,6 +49,10 @@ class GraphTensor : public TracingTensorHandle {
|
||||
explicit GraphTensor(TF_Output output)
|
||||
: TracingTensorHandle(kGraph), output_(output) {}
|
||||
void Release() override { delete this; }
|
||||
|
||||
tensorflow::DataType DataType() const override {
|
||||
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
|
||||
}
|
||||
TF_Output output_;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
@ -102,9 +106,18 @@ class GraphOperation : public TracingOperation {
|
||||
TF_AddInput(op_.get(), t->output_);
|
||||
return Status::OK();
|
||||
}
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"AddInputList has not been implemented yet.");
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override {
|
||||
std::vector<TF_Output> tf_outputs(inputs.size());
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
GraphTensor* t = dyn_cast<GraphTensor>(inputs[i]);
|
||||
if (!t) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Unable to cast input to GraphTensor");
|
||||
}
|
||||
tf_outputs[i] = t->output_;
|
||||
}
|
||||
TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size());
|
||||
return Status::OK();
|
||||
}
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override {
|
||||
|
@ -33,8 +33,6 @@ namespace tensorflow {
|
||||
// is needed a static_cast can be applied.
|
||||
class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
|
||||
public:
|
||||
// Returns tensor dtype.
|
||||
virtual tensorflow::DataType DataType() const = 0;
|
||||
// Returns number of dimensions.
|
||||
virtual Status NumDims(int* num_dims) const = 0;
|
||||
// Returns number of elements across all dimensions.
|
||||
|
@ -93,6 +93,15 @@ class MlirTensor : public TracingTensorHandle {
|
||||
explicit MlirTensor(Value value)
|
||||
: TracingTensorHandle(kMlir), value_(value) {}
|
||||
|
||||
tensorflow::DataType DataType() const override {
|
||||
tensorflow::DataType type;
|
||||
Status s = ConvertScalarTypeToDataType(value_.getType(), &type);
|
||||
if (!s.ok()) {
|
||||
return tensorflow::DT_INVALID;
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
void Release() override { delete this; }
|
||||
|
||||
Value getValue() { return value_; }
|
||||
@ -127,7 +136,7 @@ class MlirAbstractOp : public TracingOperation {
|
||||
Status SetDeviceName(const char* name) override;
|
||||
|
||||
Status AddInput(AbstractTensorHandle* input) override;
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override;
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override;
|
||||
|
||||
@ -464,7 +473,8 @@ Status MlirAbstractOp::SetDeviceName(const char* name) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MlirAbstractOp::AddInputList(absl::Span<AbstractTensorHandle*> inputs) {
|
||||
Status MlirAbstractOp::AddInputList(
|
||||
absl::Span<AbstractTensorHandle* const> inputs) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"AddInputList has not been implemented yet.");
|
||||
}
|
||||
|
@ -260,7 +260,8 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) {
|
||||
return MaybeInferSingleInputAttrs(h);
|
||||
}
|
||||
|
||||
Status EagerOperation::AddInputList(absl::Span<AbstractTensorHandle*> inputs) {
|
||||
Status EagerOperation::AddInputList(
|
||||
absl::Span<AbstractTensorHandle* const> inputs) {
|
||||
for (auto& input : inputs) {
|
||||
TensorHandle* h = TensorHandleFromInterface(input);
|
||||
AddTensorHandle(h);
|
||||
|
@ -80,7 +80,7 @@ class EagerOperation : public ImmediateExecutionOperation {
|
||||
Status SetAttrValue(const char* attr_name, const AttrValue& value);
|
||||
|
||||
Status AddInput(AbstractTensorHandle* input) override;
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override;
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override;
|
||||
const tensorflow::OpDef* OpDef() const override { return op_def_; };
|
||||
|
Loading…
Reference in New Issue
Block a user