@@ -21,6 +21,7 @@ limitations under the License.
2121#include < cstddef>
2222#include < memory>
2323#include < string>
24+ #include < thread>
2425#include < vector>
2526
2627#include " tensorflow/c/c_api.h"
@@ -34,20 +35,34 @@ limitations under the License.
3435#include " tensorflow/core/lib/gtl/stl_util.h"
3536#include " tensorflow/core/platform/mutex.h"
3637#include " tensorflow/core/platform/thread_annotations.h"
38+ #include " tensorflow/core/public/version.h"
3739
3840struct TFE_ContextOptions {
3941 TF_SessionOptions session_options;
40- TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
42+ TFE_ContextDevicePlacementPolicy policy{
43+ TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32};
4144};
4245
4346struct TFE_Context {
44- explicit TFE_Context (TF_Session* s) : session(s) {}
45-
46- TFE_ContextDevicePlacementPolicy policy;
47+ explicit TFE_Context (const TFE_ContextOptions& opts, TF_Session* s)
48+ : policy(opts.policy),
49+ session(s),
50+ rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
51+ pflr(new tensorflow::ProcessFunctionLibraryRuntime(
52+ session->device_mgr, opts.session_options.options.env,
53+ TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {}
54+
55+ const TFE_ContextDevicePlacementPolicy policy;
56+
57+ // Note: we cannot use C++11 thread_local here as there is no concept of a
58+ // thread-local-object-local variable in C++11.
59+ tensorflow::mutex policy_map_mu;
60+ std::unordered_map<std::thread::id, TFE_ContextDevicePlacementPolicy>
61+ thread_local_policies GUARDED_BY (policy_map_mu);
4762
4863 // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
49- TF_Session* session;
50- tensorflow::Rendezvous* rendezvous;
64+ TF_Session* const session;
65+ tensorflow::Rendezvous* const rendezvous;
5166
5267 tensorflow::mutex functions_mu;
5368 tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY (functions_mu){
@@ -56,14 +71,14 @@ struct TFE_Context {
5671 // One FunctionLibraryRuntime per device.
5772 // func_libs[i] is the FunctionLibraryRuntime corresponding to
5873 // session->devices[i].
59- std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
74+ const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
6075
6176 tensorflow::mutex cache_mu;
6277 std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
6378 tensorflow::Fprint128Hasher>
6479 kernel_cache GUARDED_BY (cache_mu);
6580
66- tensorflow::FunctionLibraryRuntime* func_lib (tensorflow::Device* d) {
81+ tensorflow::FunctionLibraryRuntime* func_lib (tensorflow::Device* d) const {
6782 return pflr->GetFLR (d->name ());
6883 }
6984
0 commit comments