Moving ComputeArgumentShapes to tpu_compile_op_support.
PiperOrigin-RevId: 324064650 Change-Id: I057003860e7849b63351d855b4a9c0a166a10cad
This commit is contained in:
parent
c1336e9a40
commit
34bd3aaad4
@ -413,46 +413,6 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */
|
|
||||||
Status TpuCompileOpKernelCommon::ComputeArgumentShapes(
|
|
||||||
const tpu::TPUCompileMetadataProto& metadata,
|
|
||||||
const std::vector<TensorShape>& dynamic_shapes,
|
|
||||||
std::vector<TensorShape>* arg_shapes) {
|
|
||||||
arg_shapes->resize(metadata.args_size());
|
|
||||||
int dynamic_shape_pos = 0;
|
|
||||||
for (int i = 0; i < metadata.args_size(); ++i) {
|
|
||||||
const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
|
|
||||||
// The XLA compiler determines the shape of each constant by inspecting the
|
|
||||||
// value of its corresponding host-memory tensor. As a result, we don't need
|
|
||||||
// to give the compiler graph-inferred shapes for constant arguments.
|
|
||||||
if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
|
|
||||||
PartialTensorShape static_shape(arg.shape());
|
|
||||||
|
|
||||||
TensorShape& shape = (*arg_shapes)[i];
|
|
||||||
if (static_shape.IsFullyDefined()) {
|
|
||||||
TF_RET_CHECK(static_shape.AsTensorShape(&shape));
|
|
||||||
} else {
|
|
||||||
TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
|
|
||||||
<< "Too few dynamic shapes";
|
|
||||||
shape = dynamic_shapes[dynamic_shape_pos++];
|
|
||||||
if (!static_shape.IsCompatibleWith(shape)) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Mismatch between static and dynamic shape for argument. Static "
|
|
||||||
"shape: ",
|
|
||||||
static_shape.DebugString(),
|
|
||||||
"; dynamic shape: ", shape.DebugString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Checks we consumed all of the dynamic shapes.
|
|
||||||
TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
|
|
||||||
<< "Too many dynamic shapes";
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Function arguments and return values lose their device assignments, so we
|
// Function arguments and return values lose their device assignments, so we
|
||||||
// must recreate them.
|
// must recreate them.
|
||||||
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(
|
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(
|
||||||
|
@ -99,15 +99,6 @@ class TpuCompileOpKernelCommon {
|
|||||||
const std::vector<TensorShape>& arg_shapes,
|
const std::vector<TensorShape>& arg_shapes,
|
||||||
TpuProgramGroupInterface* tpu_program_group) = 0;
|
TpuProgramGroupInterface* tpu_program_group) = 0;
|
||||||
|
|
||||||
// Computes shapes for each argument. Uses both the static shape from the
|
|
||||||
// metadata, and the dynamic shapes where the static shape is not
|
|
||||||
// defined. There must be one dynamic_shape for each argument with a
|
|
||||||
// partially defined shape, in index order.
|
|
||||||
static Status ComputeArgumentShapes(
|
|
||||||
const tpu::TPUCompileMetadataProto& metadata,
|
|
||||||
const std::vector<TensorShape>& dynamic_shapes,
|
|
||||||
std::vector<TensorShape>* arg_shapes);
|
|
||||||
|
|
||||||
// Performs shape inference on `computation`, filling shape_info with operator
|
// Performs shape inference on `computation`, filling shape_info with operator
|
||||||
// shapes. The shapes of the _Arg nodes are taken from `arg_shapes`.
|
// shapes. The shapes of the _Arg nodes are taken from `arg_shapes`.
|
||||||
static Status RunShapeInferenceOnComputation(
|
static Status RunShapeInferenceOnComputation(
|
||||||
|
@ -540,5 +540,43 @@ Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
|
|||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
|
||||||
|
const std::vector<TensorShape>& dynamic_shapes,
|
||||||
|
std::vector<TensorShape>* arg_shapes) {
|
||||||
|
arg_shapes->resize(metadata.args_size());
|
||||||
|
int dynamic_shape_pos = 0;
|
||||||
|
for (int i = 0; i < metadata.args_size(); ++i) {
|
||||||
|
const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
|
||||||
|
// The XLA compiler determines the shape of each constant by inspecting the
|
||||||
|
// value of its corresponding host-memory tensor. As a result, we don't need
|
||||||
|
// to give the compiler graph-inferred shapes for constant arguments.
|
||||||
|
if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
|
||||||
|
PartialTensorShape static_shape(arg.shape());
|
||||||
|
|
||||||
|
TensorShape& shape = (*arg_shapes)[i];
|
||||||
|
if (static_shape.IsFullyDefined()) {
|
||||||
|
TF_RET_CHECK(static_shape.AsTensorShape(&shape));
|
||||||
|
} else {
|
||||||
|
TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
|
||||||
|
<< "Too few dynamic shapes";
|
||||||
|
shape = dynamic_shapes[dynamic_shape_pos++];
|
||||||
|
if (!static_shape.IsCompatibleWith(shape)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Mismatch between static and dynamic shape for argument. Static "
|
||||||
|
"shape: ",
|
||||||
|
static_shape.DebugString(),
|
||||||
|
"; dynamic shape: ", shape.DebugString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Checks we consumed all of the dynamic shapes.
|
||||||
|
TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
|
||||||
|
<< "Too many dynamic shapes";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -159,6 +159,14 @@ se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
|
|||||||
TPUCompileMetadataProto* metadata,
|
TPUCompileMetadataProto* metadata,
|
||||||
NameAttrList* function_name,
|
NameAttrList* function_name,
|
||||||
std::string* mlir_module);
|
std::string* mlir_module);
|
||||||
|
|
||||||
|
// Computes shapes for each argument. Uses both the static shape from the
|
||||||
|
// metadata, and the dynamic shapes where the static shape is not
|
||||||
|
// defined. There must be one dynamic_shape for each argument with a
|
||||||
|
// partially defined shape, in index order.
|
||||||
|
Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
|
||||||
|
const std::vector<TensorShape>& dynamic_shapes,
|
||||||
|
std::vector<TensorShape>* arg_shapes);
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user