1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/graph/node_builder.h"
17
18#include <vector>
19#include "tensorflow/core/framework/node_def_util.h"
20#include "tensorflow/core/framework/versions.pb.h"
21#include "tensorflow/core/lib/core/errors.h"
22
23namespace tensorflow {
24
25NodeBuilder::NodeOut::NodeOut(Node* n, int32 i) // NOLINT(runtime/explicit)
26 : node(n),
27 error(false),
28 name(node != nullptr ? node->name() : (error = true, "")),
29 index(i),
30 dt(SafeGetOutput(node, i, &error)) {}
31
32NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t)
33 : node(nullptr), error(false), name(n.ToString()), index(i), dt(t) {}
34
35NodeBuilder::NodeOut::NodeOut()
36 : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
37
38NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
39 const OpRegistryInterface* op_registry)
40 : def_builder_(name, op_name, op_registry) {}
41
42NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
43 : def_builder_(name, op_def) {}
44
45NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
46 : def_builder_(def_builder) {}
47
48NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
49 inputs_.emplace_back(src_node, src_index);
50 DataType dt;
51 if (GetOutputType(src_node, src_index, &dt)) {
52 def_builder_.Input(src_node->name(), src_index, dt);
53 }
54 return *this;
55}
56
57NodeBuilder& NodeBuilder::Input(NodeOut src) {
58 if (src.error) {
59 AddIndexError(src.node, src.index);
60 } else {
61 inputs_.emplace_back(src.node, src.index);
62 def_builder_.Input(src.name, src.index, src.dt);
63 }
64 return *this;
65}
66
67NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
68 std::vector<NodeDefBuilder::NodeOut> srcs;
69 srcs.reserve(src_list.size());
70 for (const auto& node_out : src_list) {
71 if (node_out.error) {
72 AddIndexError(node_out.node, node_out.index);
73 } else {
74 srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
75 inputs_.emplace_back(node_out.node, node_out.index);
76 }
77 }
78 def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
79 return *this;
80}
81
82NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
83 control_inputs_.emplace_back(src_node);
84 def_builder_.ControlInput(src_node->name());
85 return *this;
86}
87
88NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
89 control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
90 src_nodes.end());
91 for (const Node* src_node : src_nodes) {
92 def_builder_.ControlInput(src_node->name());
93 }
94 return *this;
95}
96
97NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
98 def_builder_.Device(device_spec);
99 return *this;
100}
101
102Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
103 // In case of error, set *created_node to nullptr.
104 if (created_node != nullptr) *created_node = nullptr;
105 if (!errors_.empty()) {
106 return errors::InvalidArgument(str_util::Join(errors_, "\n"));
107 }
108
109 NodeDef node_def;
110 TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def));
111 TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
112 TF_RETURN_IF_ERROR(
113 CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));
114 Status status;
115 Node* node = graph->AddNode(node_def, &status);
116 if (!status.ok()) return status;
117
118 for (size_t i = 0; i < inputs_.size(); ++i) {
119 if (inputs_[i].node != nullptr) { // Skip back edges.
120 graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
121 }
122 }
123 for (Node* control_input : control_inputs_) {
124 graph->AddControlEdge(control_input, node);
125 }
126 if (created_node != nullptr) *created_node = node;
127 return Status::OK();
128}
129
130void NodeBuilder::AddIndexError(const Node* node, int i) {
131 if (node == nullptr) {
132 errors_.emplace_back(
133 strings::StrCat("Attempt to add nullptr Node to node with type ",
134 def_builder_.op_def().name()));
135 } else {
136 errors_.emplace_back(
137 strings::StrCat("Attempt to add output ", i, " of ", node->name(),
138 " not in range [0, ", node->num_outputs(),
139 ") to node with type ", def_builder_.op_def().name()));
140 }
141}
142
143bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) {
144 bool error;
145 *dt = SafeGetOutput(node, i, &error);
146 if (error) AddIndexError(node, i);
147 return !error;
148}
149
150} // namespace tensorflow
151