[Grappler] Inline multi-device functions using common_runtime/function implementation
There are two helper functions in function_optimizer.cc that help deal with V1 graphs that have missing control dependencies and "problematic" semantics:
1) AddStrictInputSemantics
Adds control edges from all data inputs to enforce "strict inputs" semantics when needed for correctness.
2) AddFrameForwardingControlEdge
Adds a control edge from an "Enter" node to pass frame information.
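
As illustration, the evaluator.py change below applies the same strict-inputs
idea from the graph-construction side. A minimal standalone sketch of that
gating pattern (TF 1.x API assumed; 'tf.square' is a hypothetical stand-in for
the evaluator's __call__):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(10)
    next_value = dataset.make_one_shot_iterator().get_next()

    # Gate the consumer on 'get_next': if the iterator is exhausted, the
    # out-of-range error fires before any part of the consumer executes.
    with tf.control_dependencies([next_value]):
        has_next_value = tf.no_op(name="iterator_has_next")
    with tf.control_dependencies([has_next_value]):
        result = tf.square(next_value)  # hypothetical consumer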
PiperOrigin-RevId: 246194205
diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py
index 51443d2..fa46d73 100644
--- a/tensorflow/contrib/eager/python/evaluator.py
+++ b/tensorflow/contrib/eager/python/evaluator.py
@@ -165,8 +165,15 @@
self.__call__(example, *args, **kwargs)
return self.all_metric_results(summary_logdir)
# Graph construction
- call_op = self.__call__(
- dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs)
+ next_value = dataset_ops.make_one_shot_iterator(dataset).get_next()
+  # Function inlining destroys strict inputs semantics (the function body may
+  # start executing before all inputs are ready). When the iterator is exhausted
+  # and raises an out-of-range error, the function body may be partially executed.
+  # To prevent this we add an explicit control dependency from the 'get_next'.
+ with ops.control_dependencies([next_value]):
+ has_next_value = control_flow_ops.no_op(name="iterator_has_next")
+ with ops.control_dependencies([has_next_value]):
+ call_op = self.__call__(next_value, *args, **kwargs)
init_op = self.init_variables()
results_op = self.all_metric_results(summary_logdir)
return (init_op, call_op, results_op)
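
In graph mode the returned (init_op, call_op, results_op) tuple is typically
driven with a loop like the one below (usage sketch, not part of this patch;
assumes the ops built above and a default graph):

    with tf.Session() as sess:
        sess.run(init_op)
        while True:
            try:
                sess.run(call_op)  # one evaluation step per dataset element
            except tf.errors.OutOfRangeError:
                break  # iterator exhausted; call_op never partially executes
        results = sess.run(results_op)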
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6c81391..fcb0d99 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -3234,6 +3234,7 @@
":lib",
":proto_text",
":protos_all_cc",
+ "@com_google_absl//absl/strings",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/clusters:utils",
"//tensorflow/core/grappler/clusters:virtual_cluster",
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 86dd2b6..bde8958 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1501,6 +1501,7 @@
"disable_inlining=", true_false(disable_inlining),
", ignore_noinline=", true_false(ignore_noinline),
", override_device=", true_false(ignore_noinline),
+ ", initialize_empty_device=", true_false(initialize_empty_device),
", keep_caller_node=", keep_caller_node_str(), ", output_control_src=",
output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs"
: "ControlOutputs");
@@ -1699,7 +1700,10 @@
for (Node* n : fbody->graph->op_nodes()) {
NodeDef ndef = n->def();
- if (options.override_device || ndef.device().empty()) {
+ if (options.override_device) {
+ ndef.set_device(caller->def().device());
+ }
+ if (options.initialize_empty_device && ndef.device().empty()) {
ndef.set_device(caller->def().device());
}
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index 450c974..3d071db 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -201,6 +201,13 @@
// If 'true' function inlining will override explicitly specified devices
// inside function body with the caller node device.
bool override_device = false;
+ // If 'true' function inlining will fill an empty device annotation inside
+ // function body with the caller node device.
+ // TODO(ezhulenev): Remove this flag. This is mostly legacy-compatibility
+ // mode. We should never explicitly define devices when we inline multi-device
+ // functions. However we do that in 'lower_function_call_op.cc' and
+ // 'function_optimizer' for now.
+ bool initialize_empty_device = false;
// Controls if we want to keep a node with the name as the function call node
// in a graph after function inlining.
KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep;
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 5d57e72..5290332 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -22,6 +22,7 @@
#include <utility>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
@@ -607,10 +608,12 @@
graph_->ToGraphDef(&item.graph);
// It's ok to skip invalid device annotations in Grappler.
- Status inferred_devices = item.InferDevicesFromGraph();
- if (!inferred_devices.ok()) {
- VLOG(3) << inferred_devices.error_message();
+ for (const Device* d : device_set_->devices()) {
+ Status added_device = item.AddDevice(d->name());
+ if (!added_device.ok()) VLOG(3) << added_device.error_message();
}
+ VLOG(3) << "Grappler available devices: "
+ << absl::StrJoin(item.devices(), ", ");
// TODO(b/114748242): Add a unit test to test this bug fix.
if (flib_def_) {
diff --git a/tensorflow/core/common_runtime/lower_function_call_op.cc b/tensorflow/core/common_runtime/lower_function_call_op.cc
index aaa1755..4df335a 100644
--- a/tensorflow/core/common_runtime/lower_function_call_op.cc
+++ b/tensorflow/core/common_runtime/lower_function_call_op.cc
@@ -60,6 +60,7 @@
// Tensorflow 2.0 Eager mode, and it has control outputs to represent
// side-effects that must always execute (see `control_ret` in FunctionDef).
inline_options.override_device = false;
+ inline_options.initialize_empty_device = true;
inline_options.output_control_src = OutputControlSrc::kControlOutputs;
} else {
// Native function call (node.type_string() is the function name). These
diff --git a/tensorflow/core/common_runtime/lower_functional_ops.h b/tensorflow/core/common_runtime/lower_functional_ops.h
index 297f585..84d15a1 100644
--- a/tensorflow/core/common_runtime/lower_functional_ops.h
+++ b/tensorflow/core/common_runtime/lower_functional_ops.h
@@ -32,10 +32,8 @@
class LowerFunctionalOpsPass : public GraphOptimizationPass {
public:
LowerFunctionalOpsPass() = default;
- LowerFunctionalOpsPass(bool lower_function_calls,
- bool keep_lowered_nodes_fetchable)
- : lower_function_calls_(lower_function_calls),
- keep_lowered_nodes_fetchable_(keep_lowered_nodes_fetchable) {}
+ LowerFunctionalOpsPass(bool keep_lowered_nodes_fetchable)
+ : keep_lowered_nodes_fetchable_(keep_lowered_nodes_fetchable) {}
Status Run(const GraphOptimizationPassOptions& options) override;
@@ -45,10 +43,6 @@
"_lower_as_multi_device_function";
private:
- // TODO(ezhulenev): This is only required until Grappler function optimizer is
- // not migrated to use function inlining from common_runtime.
- bool lower_function_calls_ = true;
-
// If defined use the value to control if functional ops must be fetchable
// after lowering (we add IdentityN in place of all lowered nodes). If not
// defined, this option will be inferred automatically from the graph (in
diff --git a/tensorflow/core/graph/control_flow.h b/tensorflow/core/graph/control_flow.h
index 5abe77f..cbef1c2 100644
--- a/tensorflow/core/graph/control_flow.h
+++ b/tensorflow/core/graph/control_flow.h
@@ -25,6 +25,15 @@
// Control flow info for a graph node.
struct ControlFlowInfo {
+ // 'frame' and 'parent_frame' are pointers to:
+ //
+ // a) One of the Enter nodes corresponding to the loop body, if the node
+ // executes inside a loop. If multiple tensors enter the while loop, it's
+ // undefined which Enter node will be used.
+ //
+ // b) SOURCE node (node.id() == Graph::kSourceId), if the node is not inside
+ // any of the while loops.
+
const Node* frame = nullptr; // frame of a node
const Node* parent_frame = nullptr; // parent frame of a node
string frame_name; // frame name of a node
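
For intuition: frames come from V1 while loops, so here is a hedged TF 1.x
sketch (illustrative names) of a graph whose loop-body nodes carry a frame:

    import tensorflow as tf

    i0 = tf.constant(0)  # outer frame: 'frame' points at the SOURCE node
    # Each loop variable gets an Enter node; every node created by the body
    # lambda is tagged with this loop's frame in ControlFlowInfo.
    result = tf.while_loop(lambda i: tf.less(i, 10),
                           lambda i: tf.add(i, 1),
                           [i0])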
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 401f4d7..955935f 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -149,8 +149,8 @@
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_topology_view",
+ "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/utils:functions",
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 321ba1f..630fcde 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -28,9 +28,9 @@
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+#include "tensorflow/core/common_runtime/lower_while_op.h"
#include "tensorflow/core/common_runtime/placer.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -40,59 +40,25 @@
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/grappler/graph_topology_view.h"
+#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
-#include "tensorflow/core/grappler/utils/topological_sort.h"
-#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace tensorflow {
namespace grappler {
namespace {
-// WARNING: Code in this file implicitly assumes that function input and output
-// arguments are plain tensors (tensor lists are not supported). Function inputs
-// and outputs are always expanded to a single placeholder or output tensor.
-// With this assumption, the calling node's input/output ports always match
-// function input/output arguments.
-//
-// This is guaranteed by the implementation of MakeGrapplerFunctionItem.
+constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr;
// Mark functions that were created as a result of function specialization.
-constexpr char kGrapplerSpecializedFuncAttr[] = "_GrapplerSpecializedFunc";
-
-// Name of the attribute that defines the function for indirect function calls.
-constexpr char kFuncAttrName[] = "f";
-
-constexpr char kNoInlineAttr[] = "_noinline";
-
-// Name of the node that will have control edges from function input nodes, and
-// also used as a new destination for incoming control edges.
-constexpr char kInputsReadyNodeName[] = "inputs_ready";
-
-// Name of the node that will have control edges from function control output
-// nodes, and also used as a new source of outgoing control edges. This node
-// will guarantee that all side-effects inside function body will be executed
-// after function inlining.
-constexpr char kSideEffectsExecutedNodeName[] = "side_effects_executed";
-
-bool AttrIsTrue(const FunctionDef& func, const string& attr) {
- return func.attr().count(attr) != 0 && func.attr().at(attr).b();
-}
-
-bool MarkedSpecialized(const FunctionDef& func) {
- return AttrIsTrue(func, kGrapplerSpecializedFuncAttr);
-}
-
-bool MarkedNoInline(const FunctionDef& func) {
- return AttrIsTrue(func, kNoInlineAttr);
-}
+constexpr const char* const kGrapplerSpecializedFuncAttr =
+ "_GrapplerSpecializedFunc";
// There are two ways of calling a Tensorflow function:
//
@@ -114,7 +80,7 @@
return false;
}
- auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
+ auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
return func_attr != nullptr && func_attr->has_func() &&
func_attr->func().name() == func.signature().name();
}
@@ -125,7 +91,7 @@
return AttrSlice(func_node);
} else if (IsIndirectFunctionCall(func, func_node)) {
- auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
+ auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
return AttrSlice(&func_attr->func().attr());
} else {
@@ -293,52 +259,21 @@
const FunctionLibraryDefinition& function_library() const {
return function_library_;
}
-
- FunctionLibraryDefinition* mutable_function_library() {
- return &function_library_;
- }
-
- FunctionLibraryRuntime* mutable_function_library_runtime() {
- InitializeFunctionLibraryRuntime();
- return flr_;
- }
+ FunctionLibraryDefinition& function_library() { return function_library_; }
const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
tensor_mapping() const {
return tensor_mapping_;
}
- const absl::flat_hash_map<string, std::vector<string>>& control_overrides()
- const {
- return control_overrides_;
- }
-
const GraphView& graph_view() const { return graph_view_; }
- const DeviceSet* devices() const {
- // Create fake devices lazily only if we need a DeviceSet.
- if (available_devices_.empty() && !item_->devices().empty()) {
- for (const string& name : item_->devices()) {
- auto device = absl::make_unique<FakeDevice>(name);
- available_device_set_.AddDevice(device.get());
- available_devices_.push_back(std::move(device));
- }
- }
- return &available_device_set_;
- }
-
bool IsFetchNode(const string& node_name) const {
return absl::c_any_of(item_->fetch, [&](const string& fetch) {
return ParseTensorName(fetch).node() == node_name;
});
}
- bool IsKeepOp(const string& node_name) const {
- return absl::c_any_of(item_->keep_ops, [&](const string& keep_node) {
- return keep_node == node_name;
- });
- }
-
bool IsTrulyConst(const string& name) const {
return TrulyConstNode(name) != nullptr;
}
@@ -382,17 +317,6 @@
}
}
- void AddControlOverrides(const NodeDef& func_node,
- const std::vector<string>& control_overrides) {
- VLOG(4) << "Add control overrides: from=" << func_node.name() << " to: ["
- << absl::StrJoin(control_overrides, ", ") << "]";
-
- control_overrides_[func_node.name()].reserve(control_overrides.size());
- for (const string& control_override : control_overrides) {
- control_overrides_[func_node.name()].push_back(control_override);
- }
- }
-
private:
static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
const GrapplerItem& item, const GraphDef& graph) {
@@ -411,39 +335,12 @@
return const_nodes;
}
- void InitializeFunctionLibraryRuntime() {
- if (!flr_) {
- Env* env = Env::Default();
- std::vector<std::unique_ptr<Device>> devices;
- devices.push_back(absl::make_unique<FakeDevice>(env, "/device:CPU:0"));
- device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
- OptimizerOptions optimizer_opts;
- optimizer_opts.set_do_function_inlining(true);
- process_flr_.reset(new ProcessFunctionLibraryRuntime(
- device_mgr_.get(), env, item_->graph.versions().producer(),
- &function_library_, optimizer_opts));
- flr_ = process_flr_->GetFLR(device_mgr_->ListDevices()[0]->name());
- }
- }
-
const GrapplerItem* item_; // must outlive this object
RewriterConfig::Toggle opt_level_;
// Function library constructed from current graph.
FunctionLibraryDefinition function_library_;
- // These fields initialized lazily only if needed.
- std::unique_ptr<DeviceMgr> device_mgr_;
- std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
- FunctionLibraryRuntime* flr_ = nullptr;
-
- // List of available `FakedDevices` (lazily initialized, see devices()).
- mutable std::vector<std::unique_ptr<Device>> available_devices_;
-
- // DeviceSet of fake devices (`FakeDevice`) constructed from
- // item_.devices() (lazily initialized).
- mutable DeviceSet available_device_set_;
-
// Nodes that are Const and not in feed.
absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
// Specialized functions.
@@ -451,24 +348,15 @@
const FunctionSpecialization>
specialized_functions_;
- // After function inlining and specialization, the optimized graph might be in
- // invalid state, nodes can read from non-existing function call nodes that
- // were inlined, or they can read from output index that is no longer valid
- // after unused outputs pruning.
+  // After function specialization, the optimized graph might be in an invalid
+  // state: nodes can read from an output index that is no longer valid after
+  // unused outputs pruning.
//
// Tensor mapping that has to be applied to the graph after all functions
// optimizations (invalidated tensor id -> optimized graph tensor id).
absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
tensor_mapping_;
- // When we inline a function into the optimized graph, we no longer have the
- // function call node to anchor control dependencies. Instead we must expand
- // each function call control output edge into multiple control dependencies
- // to all side-effectful ops inside the function body.
- //
- // Invalidated function call node name -> Inlined side-effectful nodes
- absl::flat_hash_map<string, std::vector<string>> control_overrides_;
-
// Use graph view to find active outputs of the function caller nodes.
GraphView graph_view_;
@@ -595,11 +483,10 @@
// Keep only non-const inputs.
std::vector<string> keep_inputs;
const auto& inputs = specialized_func_node->input();
- std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(keep_inputs),
- [&](const string& input) {
- return specialization.const_inputs.find(input) ==
- specialization.const_inputs.end();
- });
+ absl::c_copy_if(inputs, std::back_inserter(keep_inputs),
+ [&](const string& input) {
+ return !specialization.const_inputs.contains(input);
+ });
specialized_func_node->clear_input();
for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
@@ -613,7 +500,7 @@
}
for (const string& ctrl : specialization.control_deps) {
- if (existing_control_deps.find(ctrl) == existing_control_deps.end()) {
+ if (!existing_control_deps.contains(ctrl)) {
VLOG(3) << "Forward control dependency: input=" << ctrl;
specialized_func_node->add_input(ctrl);
}
@@ -641,8 +528,7 @@
const string& input = func_node.input(i);
if (IsControlInput(input)) break;
- if (specialization.const_inputs.find(input) ==
- specialization.const_inputs.end()) {
+ if (!specialization.const_inputs.contains(input)) {
DataType dt = tin->list().type(i);
(*attr)["Tin"].mutable_list()->add_type(dt);
}
@@ -666,8 +552,7 @@
// Keep output types of active outputs only.
for (int i = 0; i < tout->list().type_size(); ++i) {
- if (specialization.active_outputs.find(i) !=
- specialization.active_outputs.end()) {
+ if (specialization.active_outputs.contains(i)) {
DataType dt = tout->list().type(i);
(*attr)["Tout"].mutable_list()->add_type(dt);
}
@@ -683,7 +568,7 @@
} else if (IsIndirectFunctionCall(func, func_node)) {
auto* attr = specialized_func_node->mutable_attr();
- (*attr)[kFuncAttrName].mutable_func()->set_name(specialized_func_name);
+ (*attr)[kFuncAttr].mutable_func()->set_name(specialized_func_name);
} else {
return errors::InvalidArgument("Unknown function call site");
@@ -853,8 +738,7 @@
(*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
// Add specialized function to the library.
- TF_RETURN_IF_ERROR(
- ctx->mutable_function_library()->AddFunctionDef(specialized_func));
+ TF_RETURN_IF_ERROR(ctx->function_library().AddFunctionDef(specialized_func));
// Add a function call node for the specialized function.
NodeDef* specialized_func_node = optimized_graph->add_node();
@@ -881,9 +765,21 @@
// 2) Inline function calls.
// 3) Convert Graph back to the GraphDef.
+constexpr const char* const kLowerUsingSwitchMergeAttr =
+ LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
+constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
+ LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
+
using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
+// Checks if a boolean attribute is defined and its value is 'true'.
+bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
+ bool match;
+ Status s = GetNodeAttr(n->attrs(), attr_name, &match);
+ return s.ok() && match;
+}
+
// Checks if string attribute is defined and it's not empty.
bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
string match;
@@ -891,6 +787,14 @@
return s.ok() && !match.empty();
}
+bool LowerUsingSwitchMergeIsOn(const Node* n) {
+ return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr);
+}
+
+bool LowerAsMultiDeviceFunctionIsOn(const Node* n) {
+ return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr);
+}
+
bool MarkedForTpuCompilation(const Node* n) {
static constexpr const char* const kTpuReplicateAttr = "_tpu_replicate";
return CheckStringAttr(n, kTpuReplicateAttr);
@@ -977,10 +881,77 @@
return Status::OK();
}
+// Validates that no dead tensor can reach a function output.
+Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def,
+ const FunctionBody& fbody) {
+ absl::flat_hash_set<const Node*> output_nodes = {fbody.ret_nodes.begin(),
+ fbody.ret_nodes.end()};
+
+ // Find all nodes that can produce dead tensors.
+ std::vector<const Node*> dead_tensor_sources;
+ for (const Node* n : fbody.graph->nodes()) {
+ if (n->IsSwitch()) {
+ VLOG(4) << "Add dead tensors source. Switch node: " << n->name();
+ dead_tensor_sources.push_back(n);
+ continue;
+ }
+
+ // Native function call can also produce dead tensors if the function body
+ // has mergeless switches.
+ const FunctionDef* fdef = flib_def.Find(n->type_string());
+ if (fdef != nullptr) {
+ std::unique_ptr<FunctionBody> nested_fbody;
+
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(n->def(), &func));
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
+ &flib_def, &nested_fbody));
+
+ if (!ValidateNoDeadOutputs(flib_def, *nested_fbody).ok()) {
+ VLOG(4) << "Add dead tensors source. Function call: " << func.name()
+ << " node=" << n->name();
+ dead_tensor_sources.push_back(n);
+ }
+ }
+ }
+
+ for (const Node* dead_tensor_source : dead_tensor_sources) {
+ bool has_dead_output = false;
+
+ const auto is_output_node = [&](const Node* n) -> void {
+ const auto it = output_nodes.find(n);
+ if (it != output_nodes.end()) {
+ VLOG(4) << "Found a path to output node from dead tensor source: "
+ << dead_tensor_source->name() << " ---> " << (*it)->name();
+ has_dead_output = true;
+ }
+ };
+
+ // Stop DFS traversal at a Merge node or if already found a dead output.
+ const auto stop_traversal = [&has_dead_output](const Edge& edge) -> bool {
+      return !edge.src()->IsMerge() && !has_dead_output;
+ };
+
+ DFSFrom(*fbody.graph, {dead_tensor_source}, /*enter=*/is_output_node,
+ /*leave=*/{}, NodeComparatorName{},
+ /*edge_filter=*/stop_traversal);
+
+ if (has_dead_output) {
+ return errors::Internal(
+ "Can't inline a function with dead outputs. Dead tensor source: ",
+ SummarizeNode(*dead_tensor_source));
+ }
+ }
+
+ return Status::OK();
+}
+
// Makes an instance of FunctionBody for inlining from a Node.
Status MakeFunctionBodyForInlining(const Node& node,
const FunctionLibraryDefinition& flib_def,
std::unique_ptr<FunctionBody>* fbody) {
+ VLOG(3) << "Make function body for inlining: " << SummarizeNode(node);
+
// Finds a FunctionDef in a library and verifies that it exists.
const auto find_fdef = [&flib_def, &node](
const string& name,
@@ -997,8 +968,7 @@
// deprecated for a while, but we still support for compatibility reasons.
if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
NameAttrList func;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node.attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), kFuncAttr, &func));
const string grad = flib_def.FindGradient(func.name());
@@ -1029,7 +999,7 @@
grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
} else {
- // Compute numerical gradient for a function by traversing its body.
+ // Build a gradient graph from the function body.
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
@@ -1047,24 +1017,112 @@
TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
VLOG(4) << "Instantiate a function call: function=" << func.name();
- TF_RETURN_IF_ERROR(
- FunctionDefToBodyHelper(*fdef, node.attrs(), &flib_def, fbody));
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
+ &flib_def, fbody));
}
return Status::OK();
}
+// Adds a control edge from each data input to the 'caller' to enforce strict
+// inputs semantics (all inputs are ready and alive). This is required when:
+//
+// 1) The function takes resources as inputs, and it doesn't have incoming
+// control edges. In Tensorflow v2 context (eager mode) this should never
+// happen, because automatic control dependencies tracking will add a
+// control edge from the last op touching the resource. However such graphs
+// might be produced by legacy v1 code without automatic dependency
+// tracking. In this case strict function call semantics is required for
+// enforcing side effects execution order.
+//
+//    2) One of the inputs is consuming an Enter[is_constant=true] node, in
+//       which case it will always be alive, and can potentially lead to
+//       partial function execution after the last loop iteration.
+//
+// Both of these cases would be considered illegal by construction in Tensorflow
+// V2; however, we have to guarantee that graphs constructed with Tensorflow V1
+// will produce correct results.
+void AddStrictInputSemantics(Node* caller, Graph* g) {
+ const bool has_incoming_control_edges =
+ absl::c_any_of(caller->in_edges(),
+ [](const Edge* edge) { return edge->IsControlEdge(); });
+
+ const bool has_resource_input =
+ absl::c_any_of(caller->input_types(),
+ [](const DataType dtype) { return dtype == DT_RESOURCE; });
+
+ const bool has_constant_enter_input =
+ absl::c_any_of(caller->in_edges(), [](const Edge* edge) {
+ Node* src = edge->src();
+ return src->IsEnter() && CheckBoolAttr(src, "is_constant");
+ });
+
+ const bool requires_strict_semantics =
+ (!has_incoming_control_edges && has_resource_input) || // Case #1
+ (has_constant_enter_input); // Case #2
+ if (!requires_strict_semantics) return;
+
+ std::vector<const Node*> data_inputs;
+ data_inputs.reserve(caller->in_edges().size());
+
+ for (const Edge* edge : caller->in_edges()) {
+ if (edge->IsControlEdge()) continue;
+ data_inputs.push_back(edge->src());
+ }
+
+ VLOG(3) << "Add control edges from all data inputs to enforce strict "
+ "semantics with regard to function inputs";
+ for (const Node* node : data_inputs) {
+ g->AddControlEdge(g->FindNodeId(node->id()), caller);
+ }
+}
+
+// Adds a control edge from a frame node if the 'caller' is executing inside a
+// While loop (see control_flow.h for the 'frame' node explanation).
+void AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo>& info,
+ Node* caller, Graph* g) {
+ // All nodes added to the graph by v2 control flow lowering and function
+ // inlining are guaranteed to have control edges to nested function calls.
+ if (caller->id() >= info.size()) return;
+
+ // Check if a lowered node is executing inside a while loop.
+ const Node* frame = info[caller->id()].frame;
+ const bool is_in_while_loop = frame->id() != Graph::kSourceId;
+ if (!is_in_while_loop) return;
+
+ // Check if a node already has an incoming control edge. All incoming edges
+ // must be from the same execution frame (executor.cc invariant), so if we
+ // already have an incoming control edge, it's guaranteed that it will "carry"
+ // the same frame as all regular inputs.
+ const bool has_incoming_control_edges =
+ absl::c_any_of(caller->in_edges(),
+ [](const Edge* edge) { return edge->IsControlEdge(); });
+ if (has_incoming_control_edges) return;
+
+ VLOG(3) << "Add a frame forwarding control edge: from=" << frame->name()
+ << " to=" << caller->name();
+ g->AddControlEdge(g->FindNodeId(frame->id()), caller);
+}
+
+// Inlines all function calls that are safe for inlining into the main graph.
+// Also lowers control flow V2 ops (functional If/While) into the low-level V1
+// ops (Switch/Merge/...).
+//
+// Runs a placer after inlining, to keep all nodes in the graph placed.
Status InlineFunctionCalls(const GrapplerItem& item,
- const FunctionLibraryDefinition& flib_def,
- const GraphDef& input_graph,
- std::unordered_set<string>* skip_nodes,
+ const RewriterConfig::Toggle opt_level,
GraphDef* output_graph) {
- VLOG(2) << "Inline function calls";
- Graph graph(flib_def);
+ bool is_aggressive = opt_level == RewriterConfig::AGGRESSIVE;
+ VLOG(2) << "Inline function calls: grappler_item_id=" << item.id
+          << " (aggressive_mode=" << is_aggressive << ")";
+
+ FunctionLibraryDefinition flib_def =
+ FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library());
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(flib_def);
GraphConstructorOptions graph_constructor_options;
- TF_RETURN_IF_ERROR(
- ConvertGraphDefToGraph(graph_constructor_options, input_graph, &graph));
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_constructor_options,
+ item.graph, graph.get()));
using NodeNames = absl::flat_hash_set<absl::string_view>;
NodeNames fetch_nodes;
@@ -1074,20 +1132,42 @@
}
NodeNames keep_nodes(item.keep_ops.begin(), item.keep_ops.end());
+ std::vector<string> inlined_function_names;
+
+  // If a function call is inside a While loop, it must have an incoming control
+  // edge, because it will be used to pass the execution frame into the function
+  // body. All nodes without inputs in the function body (e.g. Const and NoOp)
+  // will get an extra control edge from the 'input_control_node'.
+ std::vector<ControlFlowInfo> control_flow_info;
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &control_flow_info));
+
// Function inlining always adds new nodes to the end of the list, so we keep
// iterating until we are out of nodes.
- for (int i = 2; i < graph.num_node_ids(); ++i) {
- Node* n = graph.FindNodeId(i);
-
+ for (int i = 2; i < graph->num_node_ids(); ++i) {
+ Node* n = graph->FindNodeId(i);
if (n == nullptr) continue; // deleted node
- if (MarkedForTpuCompilation(n)) continue;
- if (MarkedForXlaCompilation(n)) continue;
+
+ // Special case for lowering functional control flow ops. We do not rely on
+    // LowerFunctionalOpsPass because in Grappler we have to be more restrictive
+ // about what type of function calls we are allowed to inline.
+ if (LowerUsingSwitchMergeIsOn(n)) {
+ VLOG(2) << "Lower functional control flow op: " << SummarizeNode(*n);
+ AddStrictInputSemantics(n, graph.get());
+ AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
+
+ if (n->type_string() == "If") {
+ TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), flib_def, false));
+ } else if (n->type_string() == "While") {
+ TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), flib_def, false));
+ }
+ continue;
+ }
// Skip nodes that are not function calls.
if (!IsFunctionCall(flib_def, *n)) continue;
-
- // TODO(ezhulenev): Inline multi-device functions.
- if (n->IsPartitionedCall()) continue;
+ // Skip function calls that we plan to compile later.
+ if (MarkedForTpuCompilation(n)) continue;
+ if (MarkedForXlaCompilation(n)) continue;
// Function body that we will inline into the main graph. It can be a
// function instantiation, or a gradient function instantiated from
@@ -1096,8 +1176,30 @@
TF_RETURN_IF_ERROR(MakeFunctionBodyForInlining(*n, flib_def, &fbody));
InlineFunctionBodyOptions inline_options;
- inline_options.override_device = true;
- inline_options.output_control_src = OutputControlSource::kDataOutputs;
+ // Ignore '_noinline' flag in aggressive mode.
+ inline_options.ignore_noinline = is_aggressive;
+
+ // Function calls created after inlining If/While ops are always inlined as
+ // multi-device functions and are not required to pass additional Grappler
+ // validations (side effects execution validation below).
+ bool force_inline_as_multi_device = LowerAsMultiDeviceFunctionIsOn(n);
+
+ // `PartitionedCall` is a TF-2.0 function call mechanism for multi-device
+ // functions:
+ // a) Function can be multi-device, and we can't override device placements.
+ // b) Automatic control dependencies tracking guarantees that all function
+ // side-effectful nodes will have a path to one of the control outputs.
+ // Control outputs and control edges between side-effectful (stateful)
+ // nodes are used to explicitly mark the nodes that must execute, and to
+ // define their execution order.
+ if (n->IsPartitionedCall() || force_inline_as_multi_device) {
+ inline_options.override_device = false;
+ inline_options.initialize_empty_device = true;
+ inline_options.output_control_src = OutputControlSource::kControlOutputs;
+ } else {
+ inline_options.override_device = true;
+ inline_options.output_control_src = OutputControlSource::kDataOutputs;
+ }
if (fetch_nodes.contains(n->name())) {
inline_options.keep_caller_node = KeepCallerNode::kFetchable;
@@ -1121,802 +1223,125 @@
can_inline_function_call = ValidateSideEffectsExecution(
*fbody, inline_options.output_control_src,
has_outgoing_control_edges);
+
+ if (!can_inline_function_call.ok() &&
+ (is_aggressive || force_inline_as_multi_device)) {
+ VLOG(2) << "Ignore error: " << can_inline_function_call.error_message();
+ can_inline_function_call = Status::OK();
+ }
+ }
+ if (can_inline_function_call.ok()) {
+ can_inline_function_call = ValidateNoDeadOutputs(flib_def, *fbody);
}
if (can_inline_function_call.ok()) {
- VLOG(2) << "Inline function call: " << SummarizeNode(*n);
- TF_RETURN_IF_ERROR(InlineFunctionBody(graph.flib_def(), &graph, n,
+ VLOG(2) << "Inline function call node: " << n->name();
+ AddStrictInputSemantics(n, graph.get());
+ AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
+
+ TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, graph.get(), n,
fbody.get(), inline_options));
+ inlined_function_names.push_back(fbody->fdef.signature().name());
+
} else {
VLOG(2) << "Failed to inline function call node: "
- << can_inline_function_call.error_message() << "; "
- << SummarizeNode(*n);
+ << can_inline_function_call.error_message();
}
}
- graph.ToGraphDef(output_graph);
- return Status::OK();
-}
-
-// -------------------------------------------------------------------------- //
-// Inline indirect functions calls (aka PartitionedCallOp).
-//
-// When we inline indirect function calls, we instantiate the function body from
-// its FunctionDef and caller node attributes, and embed the instantiated graph
-// into the "main graph".
-//
-// In contrast to direct function calls, `PartitionedCallOp` has automatic
-// dependency tracking via input/output control edges, and we relax some of the
-// constraints that we have for direct function call inlining.
-//
-// Automatic control dependency rules:
-//
-// 1) "When a `PartitionedCallOp` function has a resource (DT_RESOURCE data
-// type) input argument it "captures" the mutable resource. This is
-// implemented by automatically adding a incoming control edge from the
-// previous side-effectful op touching that resource, and an outgoing control
-// edge to the next side-effectful op using the same resource. This
-// serializes the mutations of the resource to make graph execution
-// deterministic.
-//
-// 2) All stateful ops inside a function body are guaranteed to execute in
-// program order, this is achieved by adding control edges between stateful
-// ops at graph construction time.
-//
-// 3) Furthermore, all ops accepting the same resource as an input are
-// guaranteed to run in program order. This is also done by adding control
-// edges at graph construction time. The last op touching the resource
-// will have an outgoing control edge to all function return nodes, which
-// will guarantee that all side effects to the resource will happen before
-// function completion.
-//
-// Function call inlining must preserve side effect visibility:
-//
-// 1) All side effects to the captured resources, that happened before function
-// call must be visible to the function body nodes using that resources.
-// 2) All side effects to the captured resources, that happened inside function
-// body, must be visible to every op/function using that resource after the
-// function call completed.
-//
-// To guarantee that these properties are preserved after inlining we:
-//
-// 1) Create "input_control" NoOp. Function call node incoming control edges
-// will be forwarded *to* this node. Function inputs (Identity nodes) will
-// have a control edge *from* this node. If function has no inputs, by
-// construction it must have nodes without inputs in the function body, and
-// in this case these nodes will have a control edge *from* this node.
-
-// 2) Create "output_control" NoOp. All nodes that have incoming control edge
-// *from* the function call node, will be forwarded to this node. Function
-// outputs (Identity nodes) will have a control edge *to* this node. This
-// will guarantee that nodes that have control dependency on the function
-// call, will observe all side-effects (guaranteed by graph construction with
-// automatic control dependencies tracking).
-//
-// If after function instantiation we find a stateful or a dataset op inside
-// the function body, that is not reachable from any of the function outputs (or
-// if the function has no outputs), we do not inline it, because we can't
-// guarantee that these nodes will be executed in correct order (or executed at
-// all) after inlining.
-//
-// We do not try to add any extra control edges to make sure that all
-// side-effectful nodes will be executed, that should be handled at graph
-// construction time.
-
-struct MaybeDeadOutput {
- const NodeDef* dead_tensor_src;
- const NodeDef* output_node_dst;
-};
-
-// Finds all function outputs that might return a dead tensor. This can happen
-// if there is no `Merge` node on the path from the `Switch` node, to the
-// function output.
-Status MaybeDeadOutputs(const FunctionOptimizerContext& ctx,
- const GrapplerFunctionItem& item,
- std::vector<MaybeDeadOutput>* maybe_dead) {
- VLOG(3) << "Find function outputs that might return dead tensors: item.id="
- << item.id;
- DCHECK(maybe_dead->empty()) << "Input argument must be an empty vector";
-
- std::vector<const NodeDef*> dead_tensor_srcs;
- for (const NodeDef& node : item.graph.node()) {
- if (IsSwitch(node)) {
- VLOG(4) << "Add dead tensors source. Switch node: " << node.name();
- dead_tensor_srcs.push_back(&node);
- continue;
- }
-
- // Regular (aka 'direct') function call can also produce dead tensors if
- // the function body has mergeless switches.
- const FunctionDef* func = ctx.function_library().Find(node.op());
- if (func != nullptr) {
- GrapplerFunctionItem func_item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
- *func, FunctionInstantiationAttributes(*func, node),
- ctx.function_library(), ctx.graph_version(), &func_item));
-
- std::vector<MaybeDeadOutput> func_dead_outputs;
- TF_RETURN_IF_ERROR(MaybeDeadOutputs(ctx, func_item, &func_dead_outputs));
-
- if (!func_dead_outputs.empty()) {
- VLOG(4) << "Add dead tensors source. Function call: " << node.op()
- << " node=" << node.name();
- dead_tensor_srcs.push_back(&node);
- }
- }
- }
-
- // If we do not have dead tensor sources in the function body, it's
- // guaranteed that all output tensors can't become dead.
- if (dead_tensor_srcs.empty()) return Status::OK();
-
- // Names of the function body nodes that return function output values.
- absl::flat_hash_set<absl::string_view> output_nodes;
- for (const auto& output_arg : item.outputs()) {
- output_nodes.insert(output_arg.node_name);
- }
-
- GraphTopologyView topology_view;
- TF_RETURN_IF_ERROR(topology_view.InitializeFromGraph(item.graph));
-
- for (const NodeDef* dead_tensor_src : dead_tensor_srcs) {
- DfsTraversal(topology_view, {dead_tensor_src},
- TraversalDirection::kFollowOutputs,
- // Stop traversal when reached first `Merge` node.
- DfsPredicates::Advance(
- [](const NodeDef* node) { return !IsMerge(*node); }),
- // If we reached output node, add MaybeDeadOutput edge.
- DfsCallbacks::PreOrder([&](const NodeDef* node) {
- if (output_nodes.find(node->name()) != output_nodes.end()) {
- maybe_dead->push_back({dead_tensor_src, node});
- }
- }));
- }
-
- return Status::OK();
-}
-
-// Returns `Status::OK()` iff `node` is an indirect function call of `func`, and
-// we know how to inline it into the main graph, otherwise returns and error
-// indicating why the function call is not inlinable.
-Status IsInlinableIndirectFunctionCall(const FunctionOptimizerContext& ctx,
- const FunctionDef& func,
- const NodeDef& func_node) {
- // We inline direct function calls above, using different rules.
- if (!IsIndirectFunctionCall(func, func_node)) {
- return errors::InvalidArgument("Unsupported function call type: ",
- SummarizeNodeDef(func_node));
- }
-
- if (MarkedNoInline(func)) {
- return errors::FailedPrecondition(
- "Can't inline function marked with '_noinline': ",
- SummarizeNodeDef(func_node));
- }
-
- // Function specialization and inlining must be mutually exclusive.
- if (MarkedSpecialized(func)) {
- return errors::FailedPrecondition(
- "Can't inline function created in Grappler function specialization: ",
- SummarizeNodeDef(func_node));
- }
-
- // We can't inline functions that are in a fetch set, because it would
- // invalidate fetch tensors (function call node fully inlined and doesn't
- // exist in the optimized graph).
- if (ctx.IsFetchNode(func_node.name())) {
- return errors::FailedPrecondition(
- "Can't inline function in a Grappler item fetch set: ",
- SummarizeNodeDef(func_node));
- }
-
- return Status::OK();
-}
-
-// Checks that all side-effects will be executed in well defined order. We do it
-// by checking if there is a path from stateful/dataset ops to one of the
-// control output nodes.
-Status CheckThatSideEffectsWillExecute(
- const FunctionOptimizerContext& ctx,
- const GraphTopologyView& graph_topo_view,
- const absl::flat_hash_set<string> control_output_nodes) {
- // In aggressive mode we just print a warning for side-effectful nodes that
- // might not be executed after inlining.
- const bool aggressive = ctx.opt_level() == RewriterConfig::AGGRESSIVE;
-
- for (const NodeDef& func_body_node : graph_topo_view.graph()->node()) {
- const bool node_must_execute =
- IsDataset(func_body_node) ||
- IsStateful(func_body_node, &ctx.function_library());
-
- // If op has DT_RESOURCE argument it will be marked as stateful, though if
- // it only reads from that resource, it's allowed to prune it, because it
- // can't produce any visible side-effects.
- const bool read_only = IsReadVariableOp(func_body_node);
-
- // _Retval marked as stateful, but we will remove it before inlining.
- const bool retval = IsRetval(func_body_node);
-
- if (read_only || retval || !node_must_execute) continue;
-
- VLOG(3) << "Check that node " << func_body_node.name()
- << " will execute after inlining.";
- bool will_execute = false;
-
- // Check if we reached one of the output nodes.
- const auto callbacks = DfsCallbacks::PreOrder([&](const NodeDef* node) {
- if (control_output_nodes.contains(node->name())) {
- VLOG(4) << "Found a path to control output node: " << node->name();
- will_execute = true;
- }
- });
-
- // Stop if we already proved that node will execute.
- const auto predicates = DfsPredicates::Enter(
- [&](const NodeDef* node) { return !will_execute; });
-
- DfsTraversal(graph_topo_view, {&func_body_node},
- TraversalDirection::kFollowOutputs, predicates, callbacks);
-
- if (!will_execute) {
- const string error_message = absl::StrCat(
- "Can't guarantee execution of a side-effectful node, that is not "
- "reachable from function outputs. Function body node: ",
- SummarizeNodeDef(func_body_node));
-
- if (aggressive) {
- LOG(WARNING) << error_message;
- } else {
- return errors::Internal(error_message);
- }
- }
- }
-
- return Status::OK();
-}
-
-Status PlaceInlinedFunctionBody(
- const NodeDef& func_node, const GrapplerFunctionItem& item,
- const absl::flat_hash_map<absl::string_view, int>& input_args_idx,
- FunctionOptimizerContext* ctx, GraphDef* placed_graph_def) {
- // Control flow lowering and Placer works with a Graph object.
- std::unique_ptr<Graph> func_body_graph =
- absl::make_unique<Graph>(ctx->function_library());
-
- GraphConstructorOptions opts;
- TF_RETURN_IF_ERROR(
- ConvertGraphDefToGraph(opts, item.graph, func_body_graph.get()));
+ VLOG(4) << "Inlined " << inlined_function_names.size()
+ << " function calls: " << absl::StrJoin(inlined_function_names, ", ");
// ------------------------------------------------------------------------ //
// Grappler receives the graph after PRE_PLACEMENT, Placer, and POST_PLACEMENT
- // passes, so each node has a valid device assignment. Also V2 control
- // flow ops (functional If and While) should have been lowered to V1 control
- // flow (Switch and Merge nodes). To keep the graph valid for execution we
- // must assign device to every inlined graph node, and also lower the control
- // flow.
+ // passes, so each node has a valid device assignment. After function inlining
+  // and control flow V2 lowering we have to keep the graph placed.
- GraphOptimizationPassOptions opt_options;
- opt_options.graph = &func_body_graph;
- opt_options.flib_def = ctx->mutable_function_library();
+ if (inlined_function_names.empty()) {
+ VLOG(3) << "Not placing graph after function inlining"
+ << " (did not inline any of the function calls).";
- // TODO(ezhulenev): Should we run full PRE_PLACEMENT pass here? And
- // POST_PLACEMENT after placer?
- LowerFunctionalOpsPass pass(/*lower_function_calls=*/false,
- /*keep_lowered_nodes_fetchable=*/false);
- TF_RETURN_IF_ERROR(pass.Run(opt_options));
+ } else if (item.devices().empty()) {
+ // If there are no devices available for placer, we do not place graph after
+ // function inlining. This happens when Grappler is optimizing the function
+ // library, or when a graph optimized "offline", without an active runtime
+ // session, for example as a part of batch job for graph
+ // analysis/optimization. GrapplerItem instantiated from a function library
+ // doesn't have to be fully placed after all optimizations; it will be
+ // placed by the function library runtime before execution.
+ VLOG(3) << "Not placing graph after function inlining"
+ << " (device set is empty)";
- // ------------------------------------------------------------------------ //
- // Before placing the function body nodes we pin input arguments to the
- // same device as their corresponding input nodes.
-
- for (Node* func_body_node : func_body_graph->nodes()) {
- const auto input_arg_idx = input_args_idx.find(func_body_node->name());
-
- if (input_arg_idx != input_args_idx.end()) {
- const int input_idx = input_arg_idx->second;
- const GraphView::OutputPort output_port =
- ctx->graph_view().GetRegularFanin({&func_node, input_idx});
-
- const string& input_device = output_port.node->device();
-
- if (!input_device.empty()) {
- VLOG(3) << "Pin inlined function input node '" << func_body_node->name()
- << "' to the '" << output_port.node->device() << "' device.";
- func_body_node->set_requested_device(output_port.node->device());
- } else {
- VLOG(3) << "Inlined function input node '" << func_body_node->name()
- << "' device is undefined.";
- }
- }
- }
-
- // ------------------------------------------------------------------------ //
- // After placing nodes corresponding to the function inputs, we need to assign
- // device placements to all other function body nodes.
-
- const DeviceSet* devices = ctx->devices();
-
- if (devices->devices().empty()) {
- // If there are no devices available for placer, we do not place function
- // body nodes. This happens when Grappler optimizing function library, or
- // when graph optimized "offline", without active runtime session, for
- // example as a part of batch job for graph analysis/optimization.
- // GrapplerItem instantiated from a function library doesn't have to be
- // fully placed after all optimization, it will be placed by the function
- // library runtime before execution.
- VLOG(3) << "Do not place instantiated function body.";
} else {
// If we are running in an active runtime session, Grappler will get the
// graph after initial placing is done, and we should have devices for the
// placer.
- VLOG(3) << "Run placer for instantiated function body. Devices: ["
- << absl::StrJoin(
- devices->devices(), ", ",
- [](string* out, const Device* d) { out->append(d->name()); })
- << "]";
+ VLOG(3) << "Run placer for the graph after function inlining. "
+ << "Devices: [" << absl::StrJoin(item.devices(), ", ") << "]";
- // Use function caller node device as a default for placer.
- const Device* default_device =
- devices->FindDeviceByName(func_node.device());
+ DeviceSet device_set; // does not own devices
+ std::vector<std::unique_ptr<Device>> fake_devices; // owns fake devices
- Placer placer(func_body_graph.get(), item.id, devices, default_device);
+ for (const string& name : item.devices()) {
+ auto device = absl::make_unique<FakeDevice>(name);
+ device_set.AddDevice(device.get());
+ fake_devices.push_back(std::move(device));
+ }
+
+ Placer placer(graph.get(), item.id, &device_set);
TF_RETURN_IF_ERROR(placer.Run());
}
- // Convert Graph back to the placed GraphDef.
- func_body_graph->ToGraphDef(placed_graph_def);
-
+ graph->ToGraphDef(output_graph);
return Status::OK();
}
-Status InlineIndirectFunctionCall(const NodeDef& func_node,
- const FunctionDef& func,
- FunctionOptimizerContext* ctx,
- GraphDef* optimized_graph) {
- VLOG(2) << "Inline indirect function call: " << SummarizeNodeDef(func_node);
- VLOG(4) << "Inlined function definition: " << DebugString(func);
- TF_RETURN_IF_ERROR(IsInlinableIndirectFunctionCall(*ctx, func, func_node));
+// Restores tensor mapping after function specialization: all inputs must be
+// connected to valid nodes.
+void RestoreTensorMapping(const FunctionOptimizerContext& ctx,
+ GraphDef* optimized_graph) {
+ if (ctx.tensor_mapping().empty()) return;
- const AttrSlice func_instantiation_attr =
- FunctionInstantiationAttributes(func, func_node);
-
- GrapplerFunctionItem item;
- Status item_status = MakeGrapplerFunctionItem(func, func_instantiation_attr,
- ctx->function_library(),
- ctx->graph_version(), &item);
-
- if (!item_status.ok()) {
- return errors::InvalidArgument("Failed to inline function ", func_node.op(),
- " instantiated by ", func_node.name(),
- ". Error: ", item_status.error_message());
- }
-
- // `PartitionedCallOp` invokes functions with `allow_dead_tensors = true` to
- // reset dead flag, and return default initialized tensors instead of a dead
- // tensors. There is no way to express this in a regular Tensorflow graph, so
- // we choose not to inline if a function can have dead tensors as an output
- // position. In practice `mergeless switches` should not exists in a function
- // body, because tf-eager will only use v2 control flow ops.
- std::vector<MaybeDeadOutput> maybe_dead_outputs;
- TF_RETURN_IF_ERROR(MaybeDeadOutputs(*ctx, item, &maybe_dead_outputs));
- if (!maybe_dead_outputs.empty()) {
- struct MaybeDeadOutputFormatter {
- void operator()(string* out, const MaybeDeadOutput& md) const {
- absl::StrAppend(out, SummarizeNodeDef(*md.dead_tensor_src));
- }
- };
- return errors::FailedPrecondition(
- "Can't inline function with dead outputs. Dead tensor sources (size = ",
- maybe_dead_outputs.size(), "): ",
- absl::StrJoin(maybe_dead_outputs, "\n", MaybeDeadOutputFormatter()));
- }
-
- GraphView::InputPort control_input_port =
- ctx->graph_view().GetInputPort(func_node.name(), Graph::kControlSlot);
- GraphView::OutputPort control_output_port =
- ctx->graph_view().GetOutputPort(func_node.name(), Graph::kControlSlot);
-
- // Nodes that have side effects to the captured resources.
- std::vector<string> happens_before;
- absl::c_transform(
- ctx->graph_view().GetFanin(control_input_port),
- std::back_inserter(happens_before),
- [](const GraphView::OutputPort port) { return port.node->name(); });
-
- VLOG(3) << "Happens before set (size = " << happens_before.size()
- << "): " << absl::StrJoin(happens_before, ", ");
-
- // Nodes that must observe side effects to the captured resources.
- std::vector<string> happens_after;
- absl::c_transform(
- ctx->graph_view().GetFanout(control_output_port),
- std::back_inserter(happens_after),
- [](const GraphView::InputPort port) { return port.node->name(); });
-
- VLOG(3) << "Happens after set (size = " << happens_after.size()
- << "): " << absl::StrJoin(happens_after, ", ");
-
- // Regular (data) inputs to the function call.
- std::vector<SafeTensorId> inputs;
- for (const string& input : func_node.input()) {
- SafeTensorId tensor_id = ParseTensorName(input);
- if (tensor_id.index() == Graph::kControlSlot) break;
- inputs.push_back(tensor_id);
- }
-
- // Mapping from input argument node to function input position.
- absl::flat_hash_map<absl::string_view, int> input_args_idx;
- for (const InputArgInstantiation& input_arg : item.inputs()) {
- const int idx = input_args_idx.size();
- input_args_idx[input_arg.node_name] = idx;
- }
-
- const string prefix = strings::StrCat(func_node.name(), "/");
-
- // ------------------------------------------------------------------------ //
- // IMPORTANT: Actual inputs will be added to the following nodes at the very
- // last stage, because we don't want to have invalid edges in a function body
- // graph (control edges that depend on the nodes in the "outer" optimized
- // graph).
-
- // If one of the function inputs is a dead tensor, we must not execute any of
- // the function body nodes, and let the dead tensor flag propagate through the
- // inlined function body. We add NoOp inputs_ready node, and add control edges
- // to it from all input nodes. Inlined function arguments (Identity nodes)
- // will have a control dependency on it.
+ // During function specialization, we might prune unused function outputs. We
+ // need to "close the holes" that might appear in the function outputs.
//
- // TODO(ezhulenev): We do not need to provide this guarantee for ALL nodes in
- // the function body. We must only ensure that we do not generate observable
- // side effects.
+ // Example: prune unused output "f:1"
//
- // If the function call node has incoming control edges, we will update them
- // to use this node as destination, to ensure side-effects execution order.
- NodeDef* inputs_ready_node = nullptr;
- if (func_node.input_size() > 0) {
- inputs_ready_node = item.graph.add_node();
- inputs_ready_node->set_op("NoOp");
- inputs_ready_node->set_name(kInputsReadyNodeName);
- }
-
- // All nodes that have a control edge from the function call node, will be
- // updated to have a control edge from 'side_effects_executed_node`. This node
- // will have control edges from all function control outputs (see
- // `control_ret` in FunctionDef). This a "barrier" that guarantees that all
- // ops with side effects in the function body were executed
+ // f = my_func[T=float](...) f = my_func_specialized[T=float](...)
+ // a = Identity(f:0) -> a = Identity(f:0)
+ // b = Identity(f:2) b = Identity(f:1)
//
- // If the function call node has no outgoing control edges, it means that no
- // one is interested in the function side-effect affecting captured resources.
- //
- // If node is in keep_ops set, it means that it must execute. This could
- // happen if the graph is an instantiation of a function with control output.
- NodeDef* side_effects_executed_node = nullptr;
- if (!happens_after.empty() || ctx->IsKeepOp(func_node.name())) {
- side_effects_executed_node = item.graph.add_node();
- side_effects_executed_node->set_op("NoOp");
- side_effects_executed_node->set_name(kSideEffectsExecutedNodeName);
- }
+ // Tensor mapping (size=1): [f:2 -> f:1]
+ for (NodeDef& node : *optimized_graph->mutable_node()) {
+ for (int idx = 0; idx < node.input_size(); ++idx) {
+ TensorId input_tensor = ParseTensorName(node.input(idx));
+ if (input_tensor.index() == Graph::kControlSlot) break;
- // If function executed only for the regular data outputs, it's totally safe
- // to prune side-effects. If side-effects order is important, it must be
- // captured at graph construction time via control edges.
- if (item.control_output_size() > 0 && happens_after.empty()) {
- VLOG(2) << "Function has control outputs and empty happens after set.";
- }
-
- // ------------------------------------------------------------------------ //
- // If we have a node inside the function body without inputs (e.g. Const), we
- // must attach a control dependency to it, to make sure that if a function
- // call happens inside a loop, the node will be evaluated in correct frame.
- //
- // If the function call node has no inputs and no control dependencies, it
- // means that it can't be a function call inside a loop, and we can safely
- // insert that node without inputs into the main graph.
- //
- // TODO(ezhulenev): Use FrameMap (see grappler/utils/frame.h) to find out if
- // the function is called inside a loop.
- std::vector<string> empty_inputs_hook;
- if (inputs_ready_node != nullptr) {
- empty_inputs_hook.push_back(inputs_ready_node->name());
- }
-
- // ------------------------------------------------------------------------ //
- // Grappler called after PRE_PLACEMENT and PLACEMENT passes, so we have to
- // make sure that after inlining all nodes will have valid device assignment.
-
- GraphDef placed_graph_def;
- TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody(func_node, item, input_args_idx,
- ctx, &placed_graph_def));
-
- // ------------------------------------------------------------------------ //
- // Mapping from the '_Retval' node name to the output tensor. We build this
- // mapping after the placement, because we might have inlined some of the
- // functional If/While nodes (see a call to LowerFunctionalOpsPass).
- absl::flat_hash_map<string, string> output_tensors;
-
- for (const NodeDef& func_body_node : placed_graph_def.node()) {
- if (!IsRetval(func_body_node)) continue;
- if (func_body_node.input_size() != 1) {
- return errors::Internal("_Retval node must have single input: ",
- SummarizeNodeDef(func_body_node));
- }
- output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
- }
-
- // ------------------------------------------------------------------------ //
- // After all nodes placed we need to prepare them for inlining into the
- // optimized graph: turn placeholders into identities, update nodes
- // connectivity, etc...
-
- const auto inlined_node_name = [&func_node](const string& name) -> string {
- return AddPrefixToNodeName(name, /*prefix=*/func_node.name());
- };
-
- for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
- const string& node_name = func_body_node.name();
-
- // Turn _Arg nodes added in place of input arguments into identity nodes.
- const auto input_arg_idx = input_args_idx.find(node_name);
- if (input_arg_idx != input_args_idx.end()) {
- DCHECK_EQ(0, func_body_node.input_size());
- func_body_node.set_op("Identity");
- func_body_node.mutable_attr()->erase("index");
- func_body_node.mutable_attr()->erase("shape");
- const int input_idx = input_arg_idx->second;
- func_body_node.add_input(inputs[input_idx].ToString());
-
- // Add a control dependency on 'inputs_ready' node, to guarantee that all
- // inputs are alive and all side-effects executed before function body.
- if (inputs_ready_node) {
- func_body_node.add_input(
- AsControlDependency(inlined_node_name(inputs_ready_node->name())));
- }
- } else {
- // Update inputs of the regular function body nodes.
- for (string& input : *func_body_node.mutable_input()) {
- input = inlined_node_name(input);
- }
-
- // Check if we need to ensure node execution in correct loop frame.
- bool node_needs_empty_inputs_hook =
- // We have a node to hook and node has no inputs.
- !empty_inputs_hook.empty() && func_body_node.input_size() == 0 &&
- // The inputs_ready node will always have an edge from the main graph. If
- // the function call has no regular or control inputs, we will not add an
- // inputs_ready node to the function body graph.
- node_name != kInputsReadyNodeName &&
- // The node acting as a return barrier for execution of side effects
- // might not have any inputs (in case function has no control outputs,
- // but we still added it because of non-empty happens-after set), so
- // we must make sure it's executed in correct frame.
- (node_name != kSideEffectsExecutedNodeName ||
- item.control_output_size() == 0);
-
- if (node_needs_empty_inputs_hook) {
- *func_body_node.add_input() =
- AsControlDependency(inlined_node_name(empty_inputs_hook[0]));
- }
- }
-
- // Add the function node name as a prefix 1) to node name to avoid
- // collisions; 2) to frame name to avoid multiple LoopCond nodes in one
- // frame after inlining.
- TF_RETURN_IF_ERROR(
- AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &func_body_node));
-
- // After inlining into the optimized graph, NodeDef must have all attributes
- // defined, which is not required for a node in a FunctionDef.
- const OpDef* op_def;
- TF_RETURN_IF_ERROR(
- ctx->function_library().LookUpOpDef(func_body_node.op(), &op_def));
- AddDefaultsToNodeDef(*op_def, &func_body_node);
- }
-
- // ------------------------------------------------------------------------ //
- // Check that after inlining all side-effects will be executed in well defined
- // order. We do it by checking if there is a path from stateful/dataset ops to
- // one of the control output nodes.
-
- // Names of the inlined control output nodes.
- absl::flat_hash_set<string> inlined_control_output_nodes;
- for (const ControlOutput& control_output : item.control_outputs()) {
- inlined_control_output_nodes.insert(
- inlined_node_name(control_output.node_name));
- }
-
- // Construct a graph topology view for DFS traversals (skip invalid edges for
- // input nodes connected to nodes in the optimized graph).
- GraphTopologyView placed_topo_view(/*skip_invalid_edges=*/true);
- TF_RETURN_IF_ERROR(placed_topo_view.InitializeFromGraph(placed_graph_def));
- TF_RETURN_IF_ERROR(CheckThatSideEffectsWillExecute(
- *ctx, placed_topo_view, inlined_control_output_nodes));
-
- // ------------------------------------------------------------------------ //
- // Move all the nodes to the optimized graph after successful preprocessing.
-
- if (inputs_ready_node != nullptr) {
- string inlined_node = inlined_node_name(inputs_ready_node->name());
- absl::optional<int> node_idx = placed_topo_view.GetNodeIndex(inlined_node);
-
- absl::flat_hash_set<string> input_nodes;
- for (const string& input : func_node.input()) {
- SafeTensorId tensor = ParseTensorName(input);
-
- // Input node might have been a function call that was already inlined.
- auto it = ctx->tensor_mapping().find(tensor);
- while (it != ctx->tensor_mapping().end()) {
- tensor = it->second;
- it = ctx->tensor_mapping().find(tensor);
- }
-
- if (input_nodes.insert(tensor.node()).second) {
- placed_graph_def.mutable_node(*node_idx)->add_input(
- AsControlDependency(tensor.node()));
+ auto mapping = ctx.tensor_mapping().find(input_tensor);
+ if (mapping != ctx.tensor_mapping().end()) {
+ node.set_input(idx, mapping->second.ToString());
}
}
}
-
- if (side_effects_executed_node != nullptr) {
- string inlined_node = inlined_node_name(side_effects_executed_node->name());
- absl::optional<int> node_idx = placed_topo_view.GetNodeIndex(inlined_node);
-
- // Add control edges from all control output nodes.
- for (const string& node_name : inlined_control_output_nodes) {
- placed_graph_def.mutable_node(*node_idx)->add_input(
- AsControlDependency(node_name));
- }
-
- // Forward all control dependencies in the optimized graph to the new node.
- ctx->AddControlOverrides(func_node, {inlined_node});
- }
-
- for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
- // We bypass _Retval nodes and fetch tensors from `retval.input(0)`.
- if (IsRetval(func_body_node)) continue;
- optimized_graph->add_node()->Swap(&func_body_node);
- }
-
- // The indirect function call is fully inlined into the optimized graph, and
- // we do not copy the original function call node, so we have to set up a
- // tensor mapping from the old output tensors to the outputs of inlined nodes.
- int output_idx = 0;
- for (const OutputArgInstantiation& output : item.outputs()) {
- const string& output_tensor = output_tensors.at(output.node_name);
-
- const SafeTensorId from_tensor(func_node.name(), output_idx++);
- const SafeTensorId to_tensor = ParseTensorName(output_tensor);
-
- const SafeTensorId inlined_to_tensor =
- SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()),
- to_tensor.index());
-
- ctx->AddTensorMapping(from_tensor, inlined_to_tensor);
- }
-
- // If the function call node was in the keep_ops set, it means that we need
- // to keep a node with the same name in the optimized graph. We forward all
- // data consumers to inlined nodes, and we verify that the node is not in a
- // fetch set, so it's safe to assume that the function call node is only
- // required as a control edge source.
- if (ctx->IsKeepOp(func_node.name())) {
- VLOG(4) << "Add NoOp for inlined function in keep ops set.";
- NodeDef* keep_func_node = optimized_graph->add_node();
- keep_func_node->set_op("NoOp");
- keep_func_node->set_name(func_node.name());
- keep_func_node->set_device(func_node.device());
- keep_func_node->add_input(
- AsControlDependency(inlined_node_name(kSideEffectsExecutedNodeName)));
- }
-
- VLOG(3) << "Successfully inlined indirect function call: "
- << SummarizeNodeDef(func_node);
-
- return Status::OK();
-}
-
-// Restores graph invariants after function specialization and inlining: all
-// inputs must be connected to valid nodes.
-Status RestoreGraphInvariants(const FunctionOptimizerContext& ctx,
- GraphDef* optimized_graph) {
- // After function specialization and inlining the graph might be in an
- // invalid state, and some nodes can read tensors that do not exist anymore
- // in the optimized graph: a function call node was fully inlined into the
- // graph, or an output index was invalidated by the output pruning.
-
- if (!ctx.tensor_mapping().empty()) {
- for (NodeDef& node : *optimized_graph->mutable_node()) {
- for (int idx = 0; idx < node.input_size(); ++idx) {
- TensorId input_tensor = ParseTensorName(node.input(idx));
- if (input_tensor.index() == Graph::kControlSlot) break;
-
- auto mapping = ctx.tensor_mapping().find(input_tensor);
- if (mapping != ctx.tensor_mapping().end()) {
- node.set_input(idx, mapping->second.ToString());
- }
- }
- }
- }
-
- // Function inlining instantiates the function body directly into the
- // optimized graph, and we might end up with control dependencies on nodes
- // that no longer exist in the graph. We need to apply control overrides to
- // all invalidated nodes, and rewire control dependencies to the control
- // outputs node (it's also possible to rewrite a single control edge into
- // multiple edges to inlined side-effectful nodes).
-
- if (!ctx.control_overrides().empty()) {
- for (NodeDef& node : *optimized_graph->mutable_node()) {
- // Keep track of new control inputs to the node.
- absl::flat_hash_set<string> add_ctrl_inputs;
-
- // Remove all invalidated control inputs.
- for (int idx = 0; idx < node.input_size(); /* see below */) {
- // TODO(ezhulenev): Use non-allocating TensorId after migrating
- // `control_overrides()` to absl::flat_hash_set.
- SafeTensorId input_tensor = ParseTensorName(node.input(idx));
-
- auto overrides = ctx.control_overrides().find(input_tensor.node());
- if (overrides != ctx.control_overrides().end()) {
- // If this happens it's a bug in the function inlining.
- if (input_tensor.index() != Graph::kControlSlot) {
- return errors::Internal(
- "Illegal input edge from inlined function call node");
- }
- // Remove control dependency to the inlined function call node.
- node.mutable_input()->SwapElements(idx, node.input_size() - 1);
- node.mutable_input()->RemoveLast();
-
- // Keep track of all overrides.
- for (const string& override : overrides->second) {
- add_ctrl_inputs.insert(AsControlDependency(override));
- }
- } else {
- // Go to the next input only if the current one was not invalidated,
- // otherwise we need to check the swapped input as well.
- ++idx;
- }
- }
-
- // Add overrides to the node inputs.
- for (const string& ctrl_input : add_ctrl_inputs) {
- node.add_input(ctrl_input);
- }
- }
- }
-
- return Status::OK();
}
} // namespace
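
To make the remapping idea above concrete outside the TensorFlow sources, here is a minimal standalone sketch. `SketchNode` and `RestoreTensorMappingSketch` are hypothetical names; the real pass works on `NodeDef` inputs and parses tensor ids with `ParseTensorName` instead of comparing raw strings.

#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for NodeDef: a node name plus inputs like "f:0".
struct SketchNode {
  std::string name;
  std::vector<std::string> inputs;
};

// After pruning output "f:1" the mapping holds ["f:2" -> "f:1"]; every input
// that appears in the mapping is rewritten. Control inputs ("^f") always
// trail data inputs in a NodeDef, which is why the real pass (and this
// sketch) stops at the first control slot.
void RestoreTensorMappingSketch(
    const std::unordered_map<std::string, std::string>& tensor_mapping,
    std::vector<SketchNode>* graph) {
  for (SketchNode& node : *graph) {
    for (std::string& input : node.inputs) {
      if (!input.empty() && input[0] == '^') break;  // control inputs follow
      auto it = tensor_mapping.find(input);
      if (it != tensor_mapping.end()) input = it->second;  // e.g. f:2 -> f:1
    }
  }
}
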
Status FunctionOptimizer::RunFunctionOptimizerPass(
- const GrapplerItem& item, const GraphDef& graph, const int iteration,
- std::unordered_set<string>* skip_nodes, GraphDef* optimized_graph,
- bool* graph_has_unoptimized_function_calls) const {
- VLOG(3) << absl::Substitute(
- "Run function optimizer pass (iteration = $0): grappler_item_id = $1",
- iteration, item.id);
+ const GrapplerItem& item, GraphDef* optimized_graph) const {
+ VLOG(3) << "Run function optimizer pass: grappler_item_id=" << item.id;
// Inline all function calls into a graph using common_runtime/function
// implementation (see `InlineFunctionBody` function documentation).
GraphDef graph_after_inlining;
- TF_RETURN_IF_ERROR(InlineFunctionCalls(
- item, FunctionLibraryDefinition(OpRegistry::Global(), graph.library()),
- graph, skip_nodes, &graph_after_inlining));
+ TF_RETURN_IF_ERROR(
+ InlineFunctionCalls(item, opt_level_, &graph_after_inlining));
+ // Specialize function calls that we could not inline.
FunctionOptimizerContext ctx(item, opt_level_, graph_after_inlining);
- bool inline_gradients = options_.enable_symbolic_gradient_inlining;
- bool inline_func = options_.enable_function_inlining;
- bool specialize_func = options_.enable_function_specialization;
-
- // We will process all the nodes in topological order, to correctly handle
- // inlining of function call chains.
- std::vector<const NodeDef*> topo_ordered_nodes;
- TF_RETURN_IF_ERROR(
- ComputeTopologicalOrder(graph_after_inlining, &topo_ordered_nodes));
-
- for (const NodeDef* node : topo_ordered_nodes) {
- // Each node optimization can modify the optimized graph only by adding new
+ for (const NodeDef& node : graph_after_inlining.node()) {
+ // Function specialization can modify the optimized graph only by adding new
 // nodes, so we can check the node count to make sure the graph was not modified.
const int num_nodes_before = optimized_graph->node_size();
const auto is_graph_modified = [&]() {
@@ -1925,116 +1350,50 @@
return num_nodes > num_nodes_before;
};
- // Copy node from the `graph` to the `optimized_graph`.
- const auto copy_node = [&]() { *optimized_graph->add_node() = *node; };
+ // Copy node from the `graph_after_inlining` to the `optimized_graph`.
+ const auto copy_node = [&]() { *optimized_graph->add_node() = node; };
- // If we already failed to optimize this node during one of the previous
- // passes, we just give up and do not try one more time.
- if (skip_nodes->find(node->name()) != skip_nodes->end()) {
- VLOG(3) << "Skip optimization for node: " << node->name();
+ // Check whether the node is a function call (direct or indirect).
+ const FunctionDef* func = FindFunctionCall(ctx, node);
+ if (func == nullptr) {
copy_node();
continue;
}
-// Skip errors if optimized graph was not modified before error happened.
-#define TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(...) \
- do { \
- const Status _status = (__VA_ARGS__); \
- if (TF_PREDICT_FALSE(!_status.ok() && is_graph_modified())) \
- return _status; \
- if (TF_PREDICT_FALSE(!_status.ok() && !is_graph_modified())) { \
- VLOG(3) << "Skip error: " << _status.error_message(); \
- skip_nodes->insert(node->name()); \
- copy_node(); \
- } \
- } while (0)
+ const string& func_name = func->signature().name();
- // ---------------------------------------------------------------------- //
- // Inline or specialize function calls. //
- // ---------------------------------------------------------------------- //
+ // Specialize it to its instantiation context if it has something worth
+ // specializing.
+ bool specialization_worthy = IsParametrized(*func) ||
+ HasTrulyConstInputs(node, ctx) ||
+ HasUnusedOutputs(node, *func, ctx);
+ // Do not specialize if function has custom gradient.
+ const string grad_func = ctx.function_library().FindGradient(func_name);
- // Find if a node is a function call (direct or indirect).
- const FunctionDef* func = FindFunctionCall(ctx, *node);
-
- if (func != nullptr) {
- const string& func_name = func->signature().name();
-
- const bool is_indirect_func = IsIndirectFunctionCall(*func, *node);
-
- // Inline indirect function call if it's inlinable.
- if (inline_func && is_indirect_func) {
- Status inlinable = IsInlinableIndirectFunctionCall(ctx, *func, *node);
- if (inlinable.ok()) {
- TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- InlineIndirectFunctionCall(*node, *func, &ctx, optimized_graph));
- continue;
- } else {
- VLOG(2) << inlinable.error_message();
- skip_nodes->insert(node->name());
- }
+ if (grad_func.empty() && specialization_worthy) {
+ // TODO(ezhulenev): Specialize function call if input has a known shape.
+ // Specialize function body for its instantiation attributes and inputs.
+ Status status = SpecializeFunction(node, *func, &ctx, optimized_graph);
+ if (!status.ok() && is_graph_modified()) {
+ return status;
+ } else if (!status.ok() && !is_graph_modified()) {
+ VLOG(3) << "Skip specialization error: " << status.error_message();
+ copy_node();
}
-
- // Specialize it to its instantiation context if it can't be inlined,
- // and it has something worth specializing.
- bool specialization_worthy = IsParametrized(*func) ||
- HasTrulyConstInputs(*node, ctx) ||
- HasUnusedOutputs(*node, *func, ctx);
-
- // Do not specialize if function has custom gradient.
- const string grad_func = ctx.function_library().FindGradient(func_name);
-
- if (specialize_func && grad_func.empty() && specialization_worthy) {
- // TODO(ezhulenev): Specialize function call if input has a known shape.
- // Specialize function body for its instantiation attributes and inputs.
- TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- SpecializeFunction(*node, *func, &ctx, optimized_graph));
- continue;
- } else {
- VLOG(2) << "Skip function specialization: " << func->signature().name();
- skip_nodes->insert(node->name());
- }
+ continue;
+ } else {
+ VLOG(2) << "Skip function specialization: " << func->signature().name();
+ copy_node();
}
-
- // ---------------------------------------------------------------------- //
- // If we reached this point, node was not handled by any of the stages
- // (inline, specialize), simply copy the node to the optimized graph.
- copy_node();
-
-#undef TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED
}
- TF_RETURN_IF_ERROR(RestoreGraphInvariants(ctx, optimized_graph));
+ RestoreTensorMapping(ctx, optimized_graph);
// Preserve the graph version.
- *optimized_graph->mutable_versions() = graph.versions();
-
+ *optimized_graph->mutable_versions() = item.graph.versions();
// Prune unreachable functions from the library.
- if (options_.enable_trim_function_library) {
- *optimized_graph->mutable_library() =
- PruneFunctionLibrary(ctx.function_library(), *optimized_graph);
- } else {
- *optimized_graph->mutable_library() = ctx.function_library().ToProto();
- }
-
- // Before returning, we check whether any unoptimized function calls remain
- // after a single optimization pass.
- *graph_has_unoptimized_function_calls = false;
- for (const NodeDef& node : optimized_graph->node()) {
- // Check if we can inline symbolic gradient.
- if (IsSymbolicGradient(node) && inline_gradients &&
- skip_nodes->count(node.name()) == 0) {
- *graph_has_unoptimized_function_calls = true;
- break;
- }
-
- // Check if after inlining we have unoptimized function calls.
- const FunctionDef* func = FindFunctionCall(ctx, node);
- if (func != nullptr && !MarkedSpecialized(*func) &&
- skip_nodes->count(node.name()) == 0) {
- *graph_has_unoptimized_function_calls = true;
- break;
- }
- }
+ *optimized_graph->mutable_library() =
+ PruneFunctionLibrary(ctx.function_library(), *optimized_graph);
return Status::OK();
}
@@ -2047,35 +1406,7 @@
return Status::OK();
}
- // Do not retry failed function inlining or specialization.
- std::unordered_set<string> skip_nodes;
- bool graph_has_unoptimized_function_calls = false;
-
- // We'll keep running the function optimizer pass until we have inlined and
- // optimized all function call nodes.
- int iteration = 0;
- constexpr int kMaxIterations = 3;
-
- // 1. Run first optimizer pass with GrapplerItem.graph.
- TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(
- item, item.graph, 0, &skip_nodes, optimized_graph,
- &graph_has_unoptimized_function_calls));
-
- // 2. If after function inlining we have unoptimized function calls, we have
- // to run function optimization pass one more time.
- while (graph_has_unoptimized_function_calls) {
- if (iteration++ > kMaxIterations) {
- VLOG(1) << "Break function optimizer loop at iteration #" << iteration;
- break;
- }
-
- GraphDef workspace_graph;
- workspace_graph.Swap(optimized_graph);
-
- TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(
- item, workspace_graph, iteration, &skip_nodes, optimized_graph,
- &graph_has_unoptimized_function_calls));
- }
+ TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(item, optimized_graph));
return Status::OK();
}
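
The new `RunFunctionOptimizerPass` above collapses the old multi-iteration driver into a single pass: function calls are inlined once up front, and each remaining node is either copied verbatim or specialized. The subtle part is the error-recovery contract: a specialization failure may be swallowed only while nothing has yet been emitted for the current call, since partially added nodes cannot be rolled back. Below is a minimal standalone sketch of that control flow; `Node`, `Graph`, `Status`, and `RunSinglePass` here are hypothetical stand-ins, not the TensorFlow types.

#include <string>
#include <vector>

// Stub stand-ins (hypothetical) so the control flow reads on its own.
struct Node { std::string name; bool is_function_call; bool worth_specializing; };
struct Graph { std::vector<Node> nodes; };
struct Status { bool ok; std::string msg; };

// One pass: copy plain nodes, specialize worthwhile calls, and swallow a
// specialization error only if nothing was emitted for this call yet.
Status RunSinglePass(const Graph& in, Graph* out,
                     Status (*specialize)(const Node&, Graph*)) {
  for (const Node& node : in.nodes) {
    const size_t before = out->nodes.size();
    if (!node.is_function_call || !node.worth_specializing) {
      out->nodes.push_back(node);  // plain copy
      continue;
    }
    Status status = specialize(node, out);
    if (!status.ok) {
      if (out->nodes.size() > before) return status;  // cannot roll back
      out->nodes.push_back(node);  // fall back to the original call node
    }
  }
  return Status{true, ""};
}
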
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.h b/tensorflow/core/grappler/optimizers/function_optimizer.h
index ab90281..8c96bbc 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.h
@@ -41,25 +41,15 @@
private:
friend class FunctionOptimizerTest;
- struct FunctionOptimizerOptions {
- bool enable_function_inlining = true;
- bool enable_function_specialization = true;
- bool enable_symbolic_gradient_inlining = true;
- bool enable_trim_function_library = true;
- };
-
// Runs a single function optimizer pass over the `item.graph`. All nodes
// that are not function calls will be copied to the `optimized_graph`.
// Function call nodes are inlined or specialized, and the instantiated
// function bodies or specialized function call nodes are added to the
// `optimized_graph`.
- Status RunFunctionOptimizerPass(
- const GrapplerItem& item, const GraphDef& graph, const int iteration,
- std::unordered_set<string>* skip_nodes, GraphDef* optimized_graph,
- bool* graph_has_unoptimized_function_calls) const;
+ Status RunFunctionOptimizerPass(const GrapplerItem& item,
+ GraphDef* optimized_graph) const;
RewriterConfig::Toggle opt_level_;
- FunctionOptimizerOptions options_;
};
} // end namespace grappler
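
With the options struct and the skip-nodes plumbing removed, the optimizer is driven entirely through the standard `GraphOptimizer` entry point, as the tests below exercise it:

FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
GraphDef optimized_graph;
TF_EXPECT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
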
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index 828827a..1455399 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -33,12 +33,7 @@
constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
} // namespace
-class FunctionOptimizerTest : public GrapplerTest {
- protected:
- void DisableFunctionSpecialization(FunctionOptimizer* optimizer) {
- optimizer->options_.enable_function_specialization = false;
- }
-};
+class FunctionOptimizerTest : public GrapplerTest {};
TEST_F(FunctionOptimizerTest, InlineFunction_SimpleFunction) {
using test::function::NDef;
@@ -257,7 +252,6 @@
using test::function::NDef;
FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
- DisableFunctionSpecialization(&optimizer); // do not specialize noinline func
const Tensor kTwo = test::AsScalar<int64>(2);
FunctionDef func = FunctionDefHelper::Define(
@@ -513,6 +507,10 @@
item.feed.emplace_back("a", pi);
item.feed.emplace_back("b", pi);
+ const string input_x = "Func/c/input/_0";
+ const string input_y = "Func/c/input/_1";
+ const string output_z = "Func/c/output/_2";
+
// If the device set is empty, the inlined function body must not be placed.
{
GraphDef optimized_graph;
@@ -524,14 +522,14 @@
// Function body nodes are not placed; however, function input nodes
// must copy the device assignment from their input arguments.
- NDef("c/inputs_ready", "NoOp", {"^a", "^b"}, {}),
- NDef("c/x", "Identity", {"a:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
- kDevice),
- NDef("c/y", "Identity", {"b:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
- kDevice),
- NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}),
+ NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice),
+ // TODO(ezhulenev): Currently the inlined function body is "implicitly
+ // placed" via the 'inline_options.initialize_empty_device' flag.
+ NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, kDevice),
+ NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, kDevice),
- NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, kDevice)},
+ NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, kDevice)},
// Function library.
{mul_func});
@@ -555,14 +553,12 @@
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
- NDef("c/inputs_ready", "NoOp", {"^a", "^b"}, {}, kDevice),
- NDef("c/x", "Identity", {"a:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
- kDevice),
- NDef("c/y", "Identity", {"b:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
- kDevice),
- NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, kDevice),
+ NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, kDevice),
- NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, kDevice)},
+ NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, kDevice)},
// Function library.
{mul_func});
@@ -650,54 +646,68 @@
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
// Initialize variable with one of the placeholders.
- NDef("v", "VarHandleOp", {}, {{"dtype", DT_FLOAT}, {"shape", scalar}}),
+ NDef("v", "VarHandleOp", {}, {{"dtype", DT_FLOAT}, {"shape", scalar}},
+ kDevice),
NDef("init_v", "AssignVariableOp", {"v", "a"}, {{"dtype", DT_FLOAT}},
kDevice),
// Function body of a first function call inlined into the graph.
- NDef("f1/inputs_ready", "NoOp", {"^a", "^b", "^v", "^init_v"}, {},
+ NDef("Func/f1/input_control_node/_0", "NoOp", {"^init_v"}, {}, kDevice),
+
+ NDef("Func/f1/input/_1", "Identity", // input: 'x'
+ {"a", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
+ kDevice),
+ NDef("Func/f1/input/_2", "Identity", // input: 'y'
+ {"b", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
+ kDevice),
+ NDef("Func/f1/input/_3", "Identity", // input: 'v'
+ {"v", "^Func/f1/input_control_node/_0"}, {{"T", DT_RESOURCE}},
kDevice),
- NDef("f1/x", "Identity", {"a:0", "^f1/inputs_ready"}, {{"T", DT_FLOAT}},
- kDevice),
- NDef("f1/y", "Identity", {"b:0", "^f1/inputs_ready"}, {{"T", DT_FLOAT}},
- kDevice),
- NDef("f1/v", "Identity", {"v:0", "^f1/inputs_ready"},
- {{"T", DT_RESOURCE}}, kDevice),
-
- NDef("f1/one", "Const", {"^f1/inputs_ready"},
+ NDef("f1/one", "Const", {"^Func/f1/input_control_node/_0"},
{{"dtype", DT_FLOAT}, {"value", kOne}}, kDevice),
- NDef("f1/add", "AssignAddVariableOp", {"f1/v", "f1/one"},
+ NDef("f1/mul", "Mul", {"Func/f1/input/_1", "Func/f1/input/_2"},
+ {{"T", DT_FLOAT}}, kDevice),
+ NDef("f1/add", "AssignAddVariableOp", {"Func/f1/input/_3", "f1/one"},
{{"dtype", DT_FLOAT}}, kDevice),
- NDef("f1/mul", "Mul", {"f1/x", "f1/y"}, {{"T", DT_FLOAT}}, kDevice),
- NDef("f1/side_effects_executed", "NoOp", {"^f1/add"}, {}, kDevice),
+ NDef("Func/f1/output/_4", "Identity", {"f1/mul"}, {{"T", DT_FLOAT}},
+ kDevice),
+ NDef("Func/f1/output_control_node/_5", "NoOp", {"^f1/add"}, {}, kDevice),
// Function body of a second function call also inlined into the graph,
- // and input nodes read directly from the inlined nodes of the first
- // function call.
- NDef("f2/inputs_ready", "NoOp",
- {"^v", "^f1/mul", "^f1/side_effects_executed"}, {}, kDevice),
+ // and input nodes read from the output nodes of the first function call.
+ NDef("Func/f2/input_control_node/_6", "NoOp",
+ {"^Func/f1/output_control_node/_5"}, {}, kDevice),
- NDef("f2/x", "Identity", {"f1/mul:0", "^f2/inputs_ready"},
+ NDef("Func/f2/input/_7", "Identity", // input: 'x'
+ {"Func/f1/output/_4", "^Func/f2/input_control_node/_6"},
{{"T", DT_FLOAT}}, kDevice),
- NDef("f2/y", "Identity", {"f1/mul:0", "^f2/inputs_ready"},
+ NDef("Func/f2/input/_8", "Identity", // input: 'y'
+ {"Func/f1/output/_4", "^Func/f2/input_control_node/_6"},
{{"T", DT_FLOAT}}, kDevice),
- NDef("f2/v", "Identity", {"v:0", "^f2/inputs_ready"},
- {{"T", DT_RESOURCE}}, kDevice),
+ NDef("Func/f2/input/_9", "Identity", // input: 'v'
+ {"v", "^Func/f2/input_control_node/_6"}, {{"T", DT_RESOURCE}},
+ kDevice),
- NDef("f2/one", "Const", {"^f2/inputs_ready"},
+ NDef("f2/one", "Const", {"^Func/f2/input_control_node/_6"},
{{"dtype", DT_FLOAT}, {"value", kOne}}, kDevice),
- NDef("f2/add", "AssignAddVariableOp", {"f2/v", "f2/one"},
+ NDef("f2/add", "AssignAddVariableOp", {"Func/f2/input/_9", "f2/one"},
{{"dtype", DT_FLOAT}}, kDevice),
- NDef("f2/mul", "Mul", {"f2/x", "f2/y"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("f2/mul", "Mul", {"Func/f2/input/_7", "Func/f2/input/_8"},
+ {{"T", DT_FLOAT}}, kDevice),
- NDef("f2/side_effects_executed", "NoOp", {"^f2/add"}, {}, kDevice),
+ NDef("Func/f2/output/_10", "Identity", {"f2/mul"}, {{"T", DT_FLOAT}},
+ kDevice),
+ NDef("Func/f2/output_control_node/_11", "NoOp", {"^f2/add"}, {},
+ kDevice),
- // Return values read directly from inlined nodes.
- NDef("out_1", "Identity", {"f2/mul:0"}, {{"T", DT_FLOAT}}, kDevice),
+ // Return values read from inlined output nodes.
+ NDef("out_1", "Identity", {"Func/f2/output/_10"}, {{"T", DT_FLOAT}},
+ kDevice),
NDef("out_2", "ReadVariableOp",
- {"v", "^f1/side_effects_executed", "^f2/side_effects_executed"},
+ {"v", "^Func/f1/output_control_node/_5",
+ "^Func/f2/output_control_node/_11"},
{{"dtype", DT_FLOAT}}, kDevice)},
// Function library.
@@ -757,20 +767,22 @@
GraphDef optimized_graph;
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
+ const string input_x = "Func/c/input/_0";
+ const string input_y = "Func/c/input/_1";
+ const string output_z = "Func/c/output/_2";
+
GraphDef expected = test::function::GDef(
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
// Function must be inlined and `mul` node placed on a requested device,
// and input `Identity` nodes must be colocated with their source nodes.
- NDef("c/inputs_ready", "NoOp", {"^a", "^b"}, {}, cpu0),
- NDef("c/x", "Identity", {"a:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
- cpu0),
- NDef("c/y", "Identity", {"b:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
- cpu1),
- NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}, cpu1),
+ NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
+ NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, cpu1),
+ NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, cpu1),
+ NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu1),
- NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, cpu0)},
+ NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, cpu0)},
// Function library.
{mul_func});
@@ -809,8 +821,10 @@
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("c", "NoOp", {}, {}, kDevice),
+
// Call function first time.
- NDef("f1", "PartitionedCall", {"a", "b"},
+ NDef("f1", "PartitionedCall", {"a", "b", "^c"},
{{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
{"Tout", DataTypeSlice{DT_FLOAT}},
{"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
@@ -836,31 +850,49 @@
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("c", "NoOp", {}, {}, kDevice),
+
// Function body of a first function call inlined into the graph.
- NDef("f1/inputs_ready", "NoOp", {"^a", "^b"}, {}, kDevice),
- NDef("f1/x", "Identity", {"a:0", "^f1/inputs_ready"}, {{"T", DT_FLOAT}},
+ NDef("Func/f1/input_control_node/_0", "NoOp", {"^c"}, {}, kDevice),
+
+ NDef("Func/f1/input/_1", "Identity", // input: 'x'
+ {"a", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
kDevice),
- NDef("f1/y", "Identity", {"b:0", "^f1/inputs_ready"}, {{"T", DT_FLOAT}},
+ NDef("Func/f1/input/_2", "Identity", // input: 'y'
+ {"b", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
kDevice),
- NDef("f1/mul", "Mul", {"f1/x", "f1/y"}, {{"T", DT_FLOAT}}, kDevice),
- // Control input from `inputs_ready` node is added to ensure correct
- // frame execution.
- NDef("f1/side_effects_executed", "NoOp", {"^f1/inputs_ready"}, {},
+
+ NDef("f1/mul", "Mul", {"Func/f1/input/_1", "Func/f1/input/_2"},
+ {{"T", DT_FLOAT}}, kDevice),
+
+ NDef("Func/f1/output/_3", "Identity", {"f1/mul"}, {{"T", DT_FLOAT}},
kDevice),
+ // Control input from the `input_control_node` is added to ensure
+ // correct frame execution.
+ NDef("Func/f1/output_control_node/_4", "NoOp",
+ {"^Func/f1/input_control_node/_0"}, {}, kDevice),
// Function body of a second function call also inlined into the graph,
- // and input nodes read directly from the inlined nodes of the first
+ // and input nodes read directly from the output nodes of the first
// function call, and control dependency edge removed.
- NDef("f2/inputs_ready", "NoOp", {"^f1/mul", "^f1/side_effects_executed"},
- {}, kDevice),
- NDef("f2/x", "Identity", {"f1/mul:0", "^f2/inputs_ready"},
- {{"T", DT_FLOAT}}, kDevice),
- NDef("f2/y", "Identity", {"f1/mul:0", "^f2/inputs_ready"},
- {{"T", DT_FLOAT}}, kDevice),
- NDef("f2/mul", "Mul", {"f2/x", "f2/y"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("Func/f2/input_control_node/_5", "NoOp",
+ {"^Func/f1/output_control_node/_4"}, {}, kDevice),
- // Return directly from inlined node of f2.
- NDef("out", "Identity", {"f2/mul:0"}, {{"T", DT_FLOAT}}, kDevice)},
+ NDef("Func/f2/input/_6", "Identity",
+ {"Func/f1/output/_3", "^Func/f2/input_control_node/_5"},
+ {{"T", DT_FLOAT}}, kDevice),
+ NDef("Func/f2/input/_7", "Identity",
+ {"Func/f1/output/_3", "^Func/f2/input_control_node/_5"},
+ {{"T", DT_FLOAT}}, kDevice),
+
+ NDef("f2/mul", "Mul", {"Func/f2/input/_6", "Func/f2/input/_7"},
+ {{"T", DT_FLOAT}}, kDevice),
+ NDef("Func/f2/output/_8", "Identity", {"f2/mul"}, {{"T", DT_FLOAT}},
+ kDevice),
+
+ // Return directly from output node of f2.
+ NDef("out", "Identity", {"Func/f2/output/_8"}, {{"T", DT_FLOAT}},
+ kDevice)},
// Function library.
{mul_func});
@@ -1003,22 +1035,24 @@
NDef("b", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),
// Function body of a first function call inlined into the graph.
- NDef("fn/inputs_ready", "NoOp", {"^a", "^b"}, {}, kDevice),
- NDef("fn/x", "Identity", {"a:0", "^fn/inputs_ready"}, {{"T", DT_FLOAT}},
- kDevice),
- NDef("fn/cond", "Identity", {"b:0", "^fn/inputs_ready"},
- {{"T", DT_BOOL}}, kDevice),
- NDef("fn/switch", "Switch", {"fn/x:0", "fn/cond:0"}, {{"T", DT_FLOAT}},
- kDevice),
- NDef("fn/if_false", "Identity", {"fn/switch:0"}, {{"T", DT_FLOAT}},
+ NDef("Func/fn/input/_0", "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("Func/fn/input/_1", "Identity", {"b"}, {{"T", DT_BOOL}}, kDevice),
+
+ NDef("fn/switch", "Switch", {"Func/fn/input/_0", "Func/fn/input/_1"},
+ {{"T", DT_FLOAT}}, kDevice),
+ NDef("fn/if_false", "Identity", {"fn/switch"}, {{"T", DT_FLOAT}},
kDevice),
NDef("fn/if_true", "Identity", {"fn/switch:1"}, {{"T", DT_FLOAT}},
kDevice),
- NDef("fn/merge", "Merge", {"fn/if_false:0", "fn/if_true:0"},
+ NDef("fn/merge", "Merge", {"fn/if_false", "fn/if_true"},
{{"T", DT_FLOAT}, {"N", 2}}, kDevice),
- // Return directly from inlined node.
- NDef("out", "Identity", {"fn/merge:0"}, {{"T", DT_FLOAT}}, kDevice)},
+ NDef("Func/fn/output/_2", "Identity", {"fn/merge"}, {{"T", DT_FLOAT}},
+ kDevice),
+
+ // Return directly from inlined function output node.
+ NDef("out", "Identity", {"Func/fn/output/_2"}, {{"T", DT_FLOAT}},
+ kDevice)},
// Function library.
{no_dead_outputs});
@@ -1088,22 +1122,25 @@
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
// Inlined inputs of `b` node.
- NDef("b/inputs_ready", "NoOp", {"^a"}, {}, kDevice),
- NDef("b/x", "Identity", {"a:0", "^b/inputs_ready"}, {{"T", DT_FLOAT}},
- kDevice),
+ NDef("Func/b/input/_0", "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
// Inlined inputs of `square` node inside inlined `MySquare` function.
- NDef("b/square/inputs_ready", "NoOp", {"^b/x"}, {}, kDevice),
- NDef("b/square/x", "Identity", {"b/x:0", "^b/square/inputs_ready"},
+ NDef("Func/b/square/input/_2", "Identity", {"Func/b/input/_0"},
{{"T", DT_FLOAT}}, kDevice),
- NDef("b/square/y", "Identity", {"b/x:0", "^b/square/inputs_ready"},
+ NDef("Func/b/square/input/_3", "Identity", {"Func/b/input/_0"},
{{"T", DT_FLOAT}}, kDevice),
// Inlined mul node from the `MyMul` function.
- NDef("b/square/mul", "Mul", {"b/square/x", "b/square/y"},
+ NDef("b/square/mul", "Mul",
+ {"Func/b/square/input/_2", "Func/b/square/input/_3"},
{{"T", DT_FLOAT}}, kDevice),
- NDef("c", "Identity", {"b/square/mul:0"}, {{"T", DT_FLOAT}}, kDevice)},
+ NDef("Func/b/square/output/_4", "Identity", {"b/square/mul"},
+ {{"T", DT_FLOAT}}, kDevice),
+ NDef("Func/b/output/_1", "Identity", {"Func/b/square/output/_4"},
+ {{"T", DT_FLOAT}}, kDevice),
+
+ NDef("c", "Identity", {"Func/b/output/_1"}, {{"T", DT_FLOAT}}, kDevice)},
// Function library.
{mul_func});
@@ -1155,7 +1192,7 @@
}},
},
/* Mapping between function returns and function node outputs. */
- {{"z", "if_node:output:0"}});
+ {{"z", "if_node:output:0"}}, {{"side_effect", "if_node"}});
// Build a computation graph for:
// is_add: bool
@@ -1179,7 +1216,7 @@
{"f", FDH::FunctionRef("AddOrMul")}},
kDevice),
- NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, kDevice)},
+ NDef("d", "Identity", {"c", "^c"}, {{"T", DT_FLOAT}}, kDevice)},
// Function library.
{add_or_mul_func, add_func, mul_func});
@@ -1888,10 +1925,10 @@
TEST_F(FunctionOptimizerTest, PruningUselessLibraryFunctions) {
using test::function::NDef;
FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
- DisableFunctionSpecialization(&optimizer);
auto func = test::function::XTimesTwo();
(*func.mutable_attr())["_noinline"].set_b(true);
GrapplerItem item;
+ item.id = "test_graph";
item.graph = test::function::GDef(
{NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, "/device:CPU:0"),
NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, "/device:CPU:0"),
@@ -1906,8 +1943,9 @@
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(output.library().function().size(), 1);
- EXPECT_EQ(output.library().function(0).signature().name(), "XTimesTwo");
+ ASSERT_EQ(output.library().function().size(), 1);
+ EXPECT_EQ(output.library().function(0).signature().name(),
+ "XTimesTwo_specialized_for_y_at_test_graph");
}
} // namespace grappler
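
A note on the expected node names used throughout these tests: the common_runtime inliner emits its argument, return, and control helper nodes under a `Func/<call node>/<kind>/_<index>` prefix (for example `Func/f1/input/_1` or `Func/f1/output_control_node/_5`), and the trailing index appears to be uniquified across the whole graph rather than per call. A hypothetical helper producing names of this shape:

#include <string>

#include "absl/strings/str_cat.h"

// Hypothetical helper, not part of TensorFlow: builds names like
// "Func/f1/input/_1". The index is assumed to be a graph-wide uniquifier
// assigned by the inliner, not a per-function counter.
std::string InlinedHelperNodeName(const std::string& call_node,
                                  const std::string& kind, int index) {
  return absl::StrCat("Func/", call_node, "/", kind, "/_", index);
}
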
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 6e68a53..e07fe78 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -775,6 +776,8 @@
Status added_device = item.AddDevice(d->name());
if (!added_device.ok()) VLOG(3) << added_device.error_message();
}
+ VLOG(3) << "Grappler available devices: "
+ << absl::StrJoin(item.devices(), ", ");
// Add fetches so that the graph can be pruned.
item.fetch.swap(ret_node_names);
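
The new logging line relies on `absl::StrJoin` from `absl/strings/str_join.h`; a standalone example of the same call shape:

#include <string>
#include <vector>

#include "absl/strings/str_join.h"

std::string JoinDeviceNames(const std::vector<std::string>& devices) {
  // Produces e.g. "/device:CPU:0, /device:GPU:0".
  return absl::StrJoin(devices, ", ");
}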