tf_1.8_xla_doc
tensorflow::anonymous_namespace{tf2xla.cc} Namespace Reference

Functions

Status RewriteAndPruneGraph (Graph *graph, const tf2xla::Config &config, const std::unordered_map< string, string > &feed_remapping)
 
Status ConvertGraphToXla (std::unique_ptr< Graph > graph, xla::Client *client, xla::Computation *computation)
 Converts the TensorFlow graph into an XLA computation, by executing the graph symbolically, with each op building up the XLA HLO.
 
Status InitGraph (const GraphDef &graph_def, const tf2xla::Config &config, std::unique_ptr< Graph > *graph)
 InitGraph creates a graph based on the graph_def that may then be converted to an xla::Computation via ConvertGraphToXla.
 

Function Documentation

◆ ConvertGraphToXla()

Status tensorflow::anonymous_namespace{tf2xla.cc}::ConvertGraphToXla ( std::unique_ptr< Graph >  graph,
xla::Client *  client,
xla::Computation *  computation 
)

Converts the TensorFlow graph into an XLA computation, by executing the graph symbolically, with each op building up the XLA HLO.

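  1. tensorflow::XlaOpRegistry::RegisterCompilationKernels(): registers all XLA JIT compilation kernels on the XLA JIT devices, if they are not already registered (otherwise it does nothing).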
  2. Traverse all nodes and set each node's assigned device name, hard-coded to "/device:XLA_CPU_JIT" (the node records it as an index into the graph's table of device names).
  3. Generate the XLA arguments from the _Arg nodes (first ensuring that the index and type attrs of each node are initialized correctly).
  4. Compile the graph to an XLA UserComputation:
    1. Set the options of an XlaCompiler object.
    2. Call its CompileGraph() method.
  5. Check the compilation result. Return an error status if any output of the generated function is a constant value, which most likely means the config (i.e. the fetch ids) is invalid.
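Steps 3 and 5 can be sketched in plain C++. This is a minimal sketch, not the tf2xla implementation: ToyNode, ToyArgument, CreateToyArgs and CheckNoConstantOutputs are hypothetical names, node attrs are plain struct fields, and errors are returned as strings instead of tensorflow::Status.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, heavily simplified stand-ins; not the real TensorFlow types.
struct ToyNode {
  std::string op;    // e.g. "_Arg", "Add", "_Retval"
  int index;         // "index" attr; -1 means missing
  std::string type;  // "T" attr; empty means missing
};

struct ToyArgument {  // plays the role of an XlaCompiler::Argument
  int index;
  std::string type;
};

// Step 3: one argument per _Arg node, in the order the nodes are visited.
// Returns an error message (empty on success) instead of a Status.
std::string CreateToyArgs(const std::vector<ToyNode>& nodes,
                          std::vector<ToyArgument>* args) {
  for (const ToyNode& n : nodes) {
    if (n.op != "_Arg") continue;
    if (n.index < 0 || n.type.empty())
      return "_Arg node is missing its index or T attr";
    args->push_back({n.index, n.type});
  }
  return "";
}

// Step 5: a fetch that compiled to a constant would be dropped from the
// generated function, so any constant output is reported as an error.
std::string CheckNoConstantOutputs(const std::vector<bool>& output_is_constant) {
  int num_const = 0;
  for (bool c : output_is_constant)
    if (c) ++num_const;
  if (num_const > 0)
    return std::to_string(num_const) +
           " constant result(s); the fetch ids are probably wrong";
  return "";
}
```

Here a `std::vector<bool>` stands in for the per-output information carried by the real compilation result.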

◆ InitGraph()

Status tensorflow::anonymous_namespace{tf2xla.cc}::InitGraph ( const GraphDef &  graph_def,
const tf2xla::Config &  config,
std::unique_ptr< Graph > *  graph 
)

InitGraph creates a graph based on the graph_def that may then be converted to an xla::Computation via ConvertGraphToXla.

Google Doc:

The graph is rewritten with _Arg and _Retval nodes, representing the inputs and outputs of the function that will be compiled. Each feed id causes a new _Arg node to be created, where we first collect all existing edges pointing from the named node's output index, and then rewrite them to point from that _Arg node instead. Each fetch id causes a new _Retval node to be created, with a new edge pointing from the named node's output index to that _Retval node.
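The feed and fetch rewiring described above can be sketched on a toy edge list. ToyEdge, ToyGraph, RewriteFeed and RewriteFetch are hypothetical names; the real code operates on tensorflow::Graph and also sets attrs (such as the index) on the new nodes, which this sketch omits.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy graph. Each edge connects output `src_output` of node
// `src` to input `dst_input` of node `dst`, loosely mirroring tensorflow::Edge.
struct ToyEdge {
  std::string src;
  int src_output;
  std::string dst;
  int dst_input;
};

struct ToyGraph {
  std::vector<std::string> nodes;
  std::vector<ToyEdge> edges;
};

// Feed: add an _Arg node, then make every edge that leaves
// feed_node:output_index leave the _Arg node (output 0) instead.
void RewriteFeed(ToyGraph* g, const std::string& feed_node, int output_index,
                 const std::string& arg_name) {
  g->nodes.push_back(arg_name);
  // First collect the matching edges...
  std::vector<std::size_t> matching;
  for (std::size_t i = 0; i < g->edges.size(); ++i) {
    if (g->edges[i].src == feed_node && g->edges[i].src_output == output_index)
      matching.push_back(i);
  }
  // ...then rewrite them; each consumer keeps the same dst_input.
  for (std::size_t i : matching) {
    g->edges[i].src = arg_name;
    g->edges[i].src_output = 0;
  }
}

// Fetch: add a _Retval node with a new edge from fetch_node:output_index.
void RewriteFetch(ToyGraph* g, const std::string& fetch_node, int output_index,
                  const std::string& retval_name) {
  g->nodes.push_back(retval_name);
  g->edges.push_back({fetch_node, output_index, retval_name, 0});
}
```

Edges from other output indices of the fed node are left untouched, matching the per-output-index wording of the description.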

  1. tensorflow::ValidateConfig()
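  2. Build a FunctionLibraryDefinition from the library in graph_def, which holds the mappings for user-defined functions:
    1. gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>> function_defs_ (function name to its definition)
    2. gtl::FlatMap<string, string> func_grad_ (function name to the name of its gradient function)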
  3. tensorflow::AddPlaceholdersForFeeds()
  4. tensorflow::PruneGraphDefInto()
  5. tensorflow::AddDefaultAttrsToGraphDef()
  6. tensorflow::ConvertGraphDefToGraph()
  7. tensorflow::RewriteAndPruneGraph()
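Assuming each step above returns a Status and the first error is propagated straight to the caller (the usual TF_RETURN_IF_ERROR pattern in TensorFlow C++ code), the control flow of InitGraph can be sketched as a short-circuiting pipeline. ToyStatus, Step and RunPipeline are hypothetical, and the step bodies are stubs rather than the real functions.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Minimal stand-in for tensorflow::Status: an empty message means OK.
struct ToyStatus {
  std::string error;
  bool ok() const { return error.empty(); }
};

// A named step; the function is a stub standing in for a real call.
using Step = std::pair<std::string, std::function<ToyStatus()>>;

// Runs the steps in order and returns the first error, so later steps never
// run once one fails. `trace` records the name of every step that was started.
ToyStatus RunPipeline(const std::vector<Step>& steps,
                      std::vector<std::string>* trace) {
  for (const Step& step : steps) {
    trace->push_back(step.first);
    ToyStatus s = step.second();
    if (!s.ok()) return s;
  }
  return ToyStatus{};
}
```

For example, if the AddPlaceholdersForFeeds stub fails, the trace ends at that step and PruneGraphDefInto is never started.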

◆ RewriteAndPruneGraph()

Status tensorflow::anonymous_namespace{tf2xla.cc}::RewriteAndPruneGraph ( Graph *  graph,
const tf2xla::Config &  config,
const std::unordered_map< string, string > &  feed_remapping 
)
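  1. For each feed tensor, create an _Arg node and rewire the graph: every edge leaving the fed (placeholder) node's output is removed and replaced by an equivalent edge from the new _Arg node. The original placeholder is left without consumers and is removed in step 3.
  2. For each fetch tensor, create a _Retval node with a new edge from the fetched node's output index to it.
  3. Do a reverse BFS from the _Retval nodes (following edges from destination to source) and remove every node that cannot reach a _Retval node. This removes the placeholders created by tensorflow::AddPlaceholdersForFeeds().
  4. Connect every node without input edges to the Source node, and every node without output edges to the Sink node, using control edges.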
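Steps 3 and 4 above can be sketched on a toy graph of named nodes. This is a minimal sketch under simplifying assumptions: ToyNamedGraph, PruneUnreachable and ConnectSourceAndSink are hypothetical names, data and control edges are not distinguished, and "_SOURCE"/"_SINK" are plain strings rather than TensorFlow's special Source and Sink nodes.

```cpp
#include <cassert>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy graph: named nodes plus directed edges (src, dst).
// "_SOURCE" and "_SINK" are not stored in `nodes`.
struct ToyNamedGraph {
  std::set<std::string> nodes;
  std::set<std::pair<std::string, std::string>> edges;
};

// Step 3: breadth-first search that follows edges backwards from the
// _Retval nodes; every node it never reaches is removed with its edges.
void PruneUnreachable(ToyNamedGraph* g, const std::vector<std::string>& retvals) {
  std::set<std::string> keep(retvals.begin(), retvals.end());
  std::queue<std::string> frontier;
  for (const std::string& r : retvals) frontier.push(r);
  while (!frontier.empty()) {
    const std::string n = frontier.front();
    frontier.pop();
    for (const auto& e : g->edges) {
      // e.first feeds n, so it can also reach a _Retval node.
      if (e.second == n && keep.insert(e.first).second) frontier.push(e.first);
    }
  }
  std::set<std::string> kept_nodes;
  for (const std::string& n : g->nodes)
    if (keep.count(n)) kept_nodes.insert(n);
  std::set<std::pair<std::string, std::string>> kept_edges;
  for (const auto& e : g->edges)
    if (keep.count(e.first) && keep.count(e.second)) kept_edges.insert(e);
  g->nodes = std::move(kept_nodes);
  g->edges = std::move(kept_edges);
}

// Step 4: give every node without inputs an edge from _SOURCE and every
// node without outputs an edge to _SINK (control edges in TensorFlow).
void ConnectSourceAndSink(ToyNamedGraph* g) {
  std::set<std::string> has_input, has_output;
  for (const auto& e : g->edges) {
    has_output.insert(e.first);
    has_input.insert(e.second);
  }
  for (const std::string& n : g->nodes) {
    if (!has_input.count(n)) g->edges.emplace("_SOURCE", n);
    if (!has_output.count(n)) g->edges.emplace(n, "_SINK");
  }
}
```

For example, with edges _arg_0 → add → _retval_0 plus an unrelated placeholder → dead chain, pruning from _retval_0 keeps only _arg_0, add and _retval_0, and the fixup then adds _SOURCE → _arg_0 and _retval_0 → _SINK.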

Google Doc:

RewriteAndPruneGraph identifies input and output edges (named by the feed and fetch ids respectively), and rewrites the edges so that inputs flow from _Arg nodes, and outputs flow to _Retval nodes. This allows the symbolic graph execution to know the input and output args for the generated function.
