diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index c3d1a779c23..437b4d555af 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -1482,7 +1482,7 @@ static Status ParseDeviceAssignmentAttr( ") are not valid for the current TPU topology"); } tpu::TpuCoreLocationExternal core_location = - tpu_topology.Core(x, y, z, kTensorCore, core); + tpu_topology.Core(kTensorCore, x, y, z, core); if (replica_assignment(x, y, z, core) != -1) { return errors::InvalidArgument("Duplicate coordinates (", x, ",", y, diff --git a/tensorflow/core/tpu/tpu_executor_init_fns.inc b/tensorflow/core/tpu/tpu_executor_init_fns.inc index 23fcf05f08c..560a7bb89e2 100644 --- a/tensorflow/core/tpu/tpu_executor_init_fns.inc +++ b/tensorflow/core/tpu/tpu_executor_init_fns.inc @@ -119,6 +119,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) { TFTPU_SET_FN(executor_fn, TpuTopology_ChipBounds_Y); TFTPU_SET_FN(executor_fn, TpuTopology_ChipBounds_Z); TFTPU_SET_FN(executor_fn, TpuTopology_HasChip); + TFTPU_SET_FN(executor_fn, TpuTopology_CoreForId); TFTPU_SET_FN(executor_fn, TpuTopology_Core); TFTPU_SET_FN(executor_fn, TpuTopology_NumCores); TFTPU_SET_FN(executor_fn, TpuTopology_Cores); diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h index 230d258fa2f..494604286c4 100644 --- a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h +++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h @@ -251,9 +251,12 @@ int TpuTopology_ChipBounds_X(SE_TpuTopology* tpu_topology); int TpuTopology_ChipBounds_Y(SE_TpuTopology* tpu_topology); int TpuTopology_ChipBounds_Z(SE_TpuTopology* tpu_topology); bool TpuTopology_HasChip(SE_TpuTopology* tpu_topology, int x, int y, int z); -SE_TpuTopology_Core* TpuTopology_Core(SE_TpuTopology* tpu_topology, int x, - int y, int z, - TpuCoreTypeEnum tpu_core_type, int index); +SE_TpuTopology_Core* TpuTopology_CoreForId(SE_TpuTopology* tpu_topology, + TpuCoreTypeEnum tpu_core_type, + int id); +SE_TpuTopology_Core* TpuTopology_Core(SE_TpuTopology* tpu_topology, + TpuCoreTypeEnum tpu_core_type, int x, + int y, int z, int index); int TpuTopology_NumCores(SE_TpuTopology* tpu_topology, TpuCoreTypeEnum tpu_core_type); // 'cores' should be a preallocated array of size TpuTopology_NumCores. @@ -457,6 +460,7 @@ struct TfTpu_ExecutorApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuTopology_ChipBounds_Y); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_ChipBounds_Z); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_HasChip); + TFTPU_ADD_FN_IN_STRUCT(TpuTopology_CoreForId); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_Core); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_NumCores); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_Cores); diff --git a/tensorflow/stream_executor/tpu/tpu_topology.cc b/tensorflow/stream_executor/tpu/tpu_topology.cc index 909f5bd9dac..c6ad7cb3475 100644 --- a/tensorflow/stream_executor/tpu/tpu_topology.cc +++ b/tensorflow/stream_executor/tpu/tpu_topology.cc @@ -91,11 +91,17 @@ bool TpuTopologyExternal::HasChip(int x, int y, int z) const { return tpu::ExecutorApiFn()->TpuTopology_HasChipFn(topology_, x, y, z); } -TpuCoreLocationExternal TpuTopologyExternal::Core(int x, int y, int z, - TpuCoreTypeEnum core_type, +TpuCoreLocationExternal TpuTopologyExternal::CoreForId( + TpuCoreTypeEnum core_type, int id) const { + return TpuCoreLocationExternal( + tpu::ExecutorApiFn()->TpuTopology_CoreForIdFn(topology_, core_type, id)); +} + +TpuCoreLocationExternal TpuTopologyExternal::Core(TpuCoreTypeEnum core_type, + int x, int y, int z, int index) const { return TpuCoreLocationExternal(tpu::ExecutorApiFn()->TpuTopology_CoreFn( - topology_, x, y, z, core_type, index)); + topology_, core_type, x, y, z, index)); } std::vector<TpuCoreLocationExternal> TpuTopologyExternal::cores( diff --git a/tensorflow/stream_executor/tpu/tpu_topology.h b/tensorflow/stream_executor/tpu/tpu_topology.h index 84e13a142b6..4a719960562 100644 --- a/tensorflow/stream_executor/tpu/tpu_topology.h +++ b/tensorflow/stream_executor/tpu/tpu_topology.h @@ -75,7 +75,8 @@ class TpuTopologyExternal { int32 ChipsPerHost() const; TpuTopologyChipBoundsExternal chip_bounds() const; bool HasChip(int x, int y, int z) const; - TpuCoreLocationExternal Core(int x, int y, int z, TpuCoreTypeEnum core_type, + TpuCoreLocationExternal CoreForId(TpuCoreTypeEnum core_type, int id) const; + TpuCoreLocationExternal Core(TpuCoreTypeEnum core_type, int x, int y, int z, int index) const; std::vector<TpuCoreLocationExternal> cores(TpuCoreTypeEnum core_type) const; int IdForHost(TpuDimensionsExternal host) const;