| 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| 2 | |
| 3 | Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | you may not use this file except in compliance with the License. |
| 5 | You may obtain a copy of the License at |
| 6 | |
| 7 | http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | |
| 9 | Unless required by applicable law or agreed to in writing, software |
| 10 | distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | See the License for the specific language governing permissions and |
| 13 | limitations under the License. |
| 14 | ==============================================================================*/ |
| 15 | |
| 16 | #include "tensorflow/core/graph/node_builder.h" |
| 17 | |
| 18 | #include <vector> |
| 19 | #include "tensorflow/core/framework/node_def_util.h" |
| 20 | #include "tensorflow/core/framework/versions.pb.h" |
| 21 | #include "tensorflow/core/lib/core/errors.h" |
| 22 | |
| 23 | namespace tensorflow { |
| 24 | |
// Builds a NodeOut that refers to output `i` of the existing node `n`.
//
// If `n` is null, `error` is set to true and `name` is left empty; Input()
// then records an error instead of adding the input.
//
// NOTE(review): the initializers below read and write `error` in sequence
// (`name` may set it, then SafeGetOutput writes through `&error`). This relies
// on the members being declared in node_builder.h in the order
// node, error, name, index, dt, so that `error` is initialized before `name`
// and `dt` touch it. Confirm the declaration order before reordering either.
// SafeGetOutput is defined elsewhere and may overwrite the value set while
// initializing `name`; presumably it also reports an error for a null `n` or
// an out-of-range `i` -- confirm against its definition.
NodeBuilder::NodeOut::NodeOut(Node* n, int32 i)  // NOLINT(runtime/explicit)
    : node(n),
      error(false),
      name(node != nullptr ? node->name() : (error = true, "")),
      index(i),
      dt(SafeGetOutput(node, i, &error)) {}
| 31 | |
| 32 | NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t) |
| 33 | : node(nullptr), error(false), name(n.ToString()), index(i), dt(t) {} |
| 34 | |
| 35 | NodeBuilder::NodeOut::NodeOut() |
| 36 | : node(nullptr), error(true), index(0), dt(DT_FLOAT) {} |
| 37 | |
| 38 | NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name, |
| 39 | const OpRegistryInterface* op_registry) |
| 40 | : def_builder_(name, op_name, op_registry) {} |
| 41 | |
| 42 | NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def) |
| 43 | : def_builder_(name, op_def) {} |
| 44 | |
| 45 | NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder) |
| 46 | : def_builder_(def_builder) {} |
| 47 | |
| 48 | NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) { |
| 49 | inputs_.emplace_back(src_node, src_index); |
| 50 | DataType dt; |
| 51 | if (GetOutputType(src_node, src_index, &dt)) { |
| 52 | def_builder_.Input(src_node->name(), src_index, dt); |
| 53 | } |
| 54 | return *this; |
| 55 | } |
| 56 | |
| 57 | NodeBuilder& NodeBuilder::Input(NodeOut src) { |
| 58 | if (src.error) { |
| 59 | AddIndexError(src.node, src.index); |
| 60 | } else { |
| 61 | inputs_.emplace_back(src.node, src.index); |
| 62 | def_builder_.Input(src.name, src.index, src.dt); |
| 63 | } |
| 64 | return *this; |
| 65 | } |
| 66 | |
| 67 | NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) { |
| 68 | std::vector<NodeDefBuilder::NodeOut> srcs; |
| 69 | srcs.reserve(src_list.size()); |
| 70 | for (const auto& node_out : src_list) { |
| 71 | if (node_out.error) { |
| 72 | AddIndexError(node_out.node, node_out.index); |
| 73 | } else { |
| 74 | srcs.emplace_back(node_out.name, node_out.index, node_out.dt); |
| 75 | inputs_.emplace_back(node_out.node, node_out.index); |
| 76 | } |
| 77 | } |
| 78 | def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs)); |
| 79 | return *this; |
| 80 | } |
| 81 | |
| 82 | NodeBuilder& NodeBuilder::ControlInput(Node* src_node) { |
| 83 | control_inputs_.emplace_back(src_node); |
| 84 | def_builder_.ControlInput(src_node->name()); |
| 85 | return *this; |
| 86 | } |
| 87 | |
| 88 | NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) { |
| 89 | control_inputs_.insert(control_inputs_.end(), src_nodes.begin(), |
| 90 | src_nodes.end()); |
| 91 | for (const Node* src_node : src_nodes) { |
| 92 | def_builder_.ControlInput(src_node->name()); |
| 93 | } |
| 94 | return *this; |
| 95 | } |
| 96 | |
| 97 | NodeBuilder& NodeBuilder::Device(StringPiece device_spec) { |
| 98 | def_builder_.Device(device_spec); |
| 99 | return *this; |
| 100 | } |
| 101 | |
| 102 | Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const { |
| 103 | // In case of error, set *created_node to nullptr. |
| 104 | if (created_node != nullptr) *created_node = nullptr; |
| 105 | if (!errors_.empty()) { |
| 106 | return errors::InvalidArgument(str_util::Join(errors_, "\n" )); |
| 107 | } |
| 108 | |
| 109 | NodeDef node_def; |
| 110 | TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def)); |
| 111 | TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); |
| 112 | TF_RETURN_IF_ERROR( |
| 113 | CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer())); |
| 114 | Status status; |
| 115 | Node* node = graph->AddNode(node_def, &status); |
| 116 | if (!status.ok()) return status; |
| 117 | |
| 118 | for (size_t i = 0; i < inputs_.size(); ++i) { |
| 119 | if (inputs_[i].node != nullptr) { // Skip back edges. |
| 120 | graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i); |
| 121 | } |
| 122 | } |
| 123 | for (Node* control_input : control_inputs_) { |
| 124 | graph->AddControlEdge(control_input, node); |
| 125 | } |
| 126 | if (created_node != nullptr) *created_node = node; |
| 127 | return Status::OK(); |
| 128 | } |
| 129 | |
| 130 | void NodeBuilder::AddIndexError(const Node* node, int i) { |
| 131 | if (node == nullptr) { |
| 132 | errors_.emplace_back( |
| 133 | strings::StrCat("Attempt to add nullptr Node to node with type " , |
| 134 | def_builder_.op_def().name())); |
| 135 | } else { |
| 136 | errors_.emplace_back( |
| 137 | strings::StrCat("Attempt to add output " , i, " of " , node->name(), |
| 138 | " not in range [0, " , node->num_outputs(), |
| 139 | ") to node with type " , def_builder_.op_def().name())); |
| 140 | } |
| 141 | } |
| 142 | |
| 143 | bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) { |
| 144 | bool error; |
| 145 | *dt = SafeGetOutput(node, i, &error); |
| 146 | if (error) AddIndexError(node, i); |
| 147 | return !error; |
| 148 | } |
| 149 | |
| 150 | } // namespace tensorflow |
| 151 | |