@@ -25,29 +25,40 @@ namespace cppflow {
2525 std::vector<tensor> operator ()(std::vector<std::tuple<std::string, tensor>> inputs, std::vector<std::string> outputs);
2626 tensor operator ()(const tensor& input);
2727
28+ ~model () = default ;
29+ model (const model &model) = default ;
30+ model (model &&model) = default ;
31+ model &operator =(const model &other) = default ;
32+ model &operator =(model &&other) = default ;
33+
2834 private:
2935
30- TF_Graph* graph;
31- TF_Session* session;
36+ std::shared_ptr< TF_Graph> graph;
37+ std::shared_ptr< TF_Session> session;
3238 };
3339}
3440
41+
3542namespace cppflow {
3643
3744 model::model (const std::string &filename) {
38- this ->graph = TF_NewGraph ();
45+ this ->graph = { TF_NewGraph (), TF_DeleteGraph} ;
3946
4047 // Create the session.
41- TF_SessionOptions* session_options = TF_NewSessionOptions ();
42- TF_Buffer* run_options = TF_NewBufferFromString (" " , 0 );
43- TF_Buffer* meta_graph = TF_NewBuffer ();
48+ std::unique_ptr<TF_SessionOptions, decltype (&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions (), TF_DeleteSessionOptions};
49+ std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString (" " , 0 ), TF_DeleteBuffer};
50+ std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer (), TF_DeleteBuffer};
51+
52+ auto session_deleter = [](TF_Session* sess) {
53+ TF_DeleteSession (sess, context::get_status ());
54+ status_check (context::get_status ());
55+ };
4456
4557 int tag_len = 1 ;
4658 const char * tag = " serve" ;
47- this ->session = TF_LoadSessionFromSavedModel (session_options, run_options, filename.c_str (), &tag, tag_len, graph, meta_graph, context::get_status ());
48- TF_DeleteSessionOptions (session_options);
49- TF_DeleteBuffer (run_options);
50- // TF_DeleteBuffer(meta_graph);
59+ this ->session = {TF_LoadSessionFromSavedModel (session_options.get (), run_options.get (), filename.c_str (),
60+ &tag, tag_len, this ->graph .get (), meta_graph.get (), context::get_status ()),
61+ session_deleter};
5162
5263 status_check (context::get_status ());
5364 }
@@ -58,7 +69,7 @@ namespace cppflow {
5869 TF_Operation* oper;
5970
6071 // Iterate through the operations of a graph
61- while ((oper = TF_GraphNextOperation (this ->graph , &pos)) != nullptr ) {
72+ while ((oper = TF_GraphNextOperation (this ->graph . get () , &pos)) != nullptr ) {
6273 result.emplace_back (TF_OperationName (oper));
6374 }
6475 return result;
@@ -77,7 +88,7 @@ namespace cppflow {
7788
7889 // Operations
7990 const auto [op_name, op_idx] = parse_name (std::get<0 >(inputs[i]));
80- inp_ops[i].oper = TF_GraphOperationByName (this ->graph , op_name.c_str ());
91+ inp_ops[i].oper = TF_GraphOperationByName (this ->graph . get () , op_name.c_str ());
8192 inp_ops[i].index = op_idx;
8293
8394 if (!inp_ops[i].oper )
@@ -94,15 +105,15 @@ namespace cppflow {
94105 for (int i=0 ; i<outputs.size (); i++) {
95106
96107 const auto [op_name, op_idx] = parse_name (outputs[i]);
97- out_ops[i].oper = TF_GraphOperationByName (this ->graph , op_name.c_str ());
108+ out_ops[i].oper = TF_GraphOperationByName (this ->graph . get () , op_name.c_str ());
98109 out_ops[i].index = op_idx;
99110
100111 if (!out_ops[i].oper )
101112 throw std::runtime_error (" No operation named \" " + op_name + " \" exists" );
102113
103114 }
104115
105- TF_SessionRun (this ->session , NULL ,
116+ TF_SessionRun (this ->session . get () , NULL ,
106117 inp_ops.data (), inp_val.data (), inputs.size (),
107118 out_ops.data (), out_val.get (), outputs.size (),
108119 NULL , 0 ,NULL , context::get_status ());
0 commit comments