tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md
-cluster-tf-ops-by-host
Cluster the TensorFlow ops by host so that each function only contains ops placed on the same host
-constant-op-device-assignment
Assign device for tf.Const ops
-convert-tf-control-flow-to-scf
Convert TensorFlow control flow to SCF.
This pass can be used for all direct control flow lowerings from the TensorFlow dialect to the SCF dialect.
-prepare-tpu-computation-for-tf-export
Prepare TPU computation to be legal for export to TensorFlow
Prepares the TPU computation module attached to the _TPUCompileMlir op for TensorFlow graph export by applying transformations such as replacing or removing MLIR- or XLA-specific attributes that are not legal in a TensorFlow graph.
-tf-batch-matmul-to-tf-einsum
Replace TF BatchMatMul op by TF Einsum op.
-tf-broadcast-fold
Fold explicit broadcasts into the following operations if they support implicit broadcasting on their operand.
-tf-canonicalize-compile-and-replicate-attributes
Canonicalize compilation and replication attributes.
A pass that converts existing compilation and replication attributes into
unified attributes. For example, _tpu_replicate="cluster" in the
following code
%control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_tpu = true, use_spmd_for_xla_partitioning = false} : () -> ()
will be replaced by _replication_info="cluster" and _xla_compile_device_type="TPU".
%control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_replication_info = "cluster", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
_XlaMustCompile=true in the following code
%outputs_67, %control_68 = tf_executor.island wraps "tf.PartitionedCall"(%arg0, %outputs_0) {_XlaMustCompile = true, _collective_manager_ids = [], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\00\0A\07\0A\03TPU\10\02\0A\0E\0A\0ATPU_SYSTEM\10\012\02J\008\01\82\01\05h\01\88\01\01", device = "", executor_type = "", f = @__inference__jit_compiled_convolution_op_1510} : (tensor<4x32x32x8xf32>, tensor<*xf32>) -> tensor<*xf32>
will be replaced by _xla_compile_device_type, with its value set to the value of the op's device attribute (the empty string in this example).
%outputs_67, %control_68 = tf_executor.island wraps "tf.PartitionedCall"(%arg0, %outputs_0) {_collective_manager_ids = [], _read_only_resource_inputs = [], _xla_compile_device_type = "", config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\00\0A\07\0A\03TPU\10\02\0A\0E\0A\0ATPU_SYSTEM\10\012\02J\008\01\82\01\05h\01\88\01\01", device = "", executor_type = "", f = @__inference__jit_compiled_convolution_op_1510} : (tensor<4x32x32x8xf32>, tensor<*xf32>) -> tensor<*xf32>
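The attribute mapping in these two examples can be modeled as a plain dictionary rewrite. The sketch below is hypothetical. `canonicalize_attrs` and the dict representation of op attributes are assumptions for illustration, not TensorFlow APIs, and the sketch covers only the two cases shown above.

```python
def canonicalize_attrs(attrs: dict) -> dict:
    """Toy model of the legacy-to-unified attribute rewrite (not TensorFlow code)."""
    out = dict(attrs)
    if "_tpu_replicate" in out:
        # _tpu_replicate="cluster" -> _replication_info="cluster" plus a TPU device type.
        out["_replication_info"] = out.pop("_tpu_replicate")
        out["_xla_compile_device_type"] = "TPU"
    if out.get("_XlaMustCompile") is True:
        # _XlaMustCompile=true -> _xla_compile_device_type copied from `device`.
        # Only the `true` case from the example is modeled here.
        del out["_XlaMustCompile"]
        out["_xla_compile_device_type"] = out.get("device", "")
    return out


# The two ops from the examples above, with unrelated attributes trimmed.
metadata = canonicalize_attrs({"_tpu_replicate": "cluster", "device": "", "use_tpu": True})
call = canonicalize_attrs({"_XlaMustCompile": True, "device": "", "config": ""})
print(metadata)
print(call)
```

Copying the value from `device` explains why the converted tf.PartitionedCall above ends up with `_xla_compile_device_type = ""`.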
-tf-convert-to-legacy-compile-and-replicate-attributes
Convert unified compilation and replication attributes back to legacy attributes.
This transformation pass converts unified compilation and replication
attributes (_replication_info and _xla_compile_device_type) into legacy
attributes. This ensures that, in some cases, the unified attributes are not
exposed outside of the MLIR bridge when running the V1 pipeline. The pass
expects each op to carry either both unified attributes or neither of them;
if an op carries only one, the conversion fails.
For example, _replication_info="cluster" and
_xla_compile_device_type="TPU" in the following code
%control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_replication_info = "cluster", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
will be replaced by _tpu_replicate="cluster", as follows:
%control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_tpu = true, use_spmd_for_xla_partitioning = false} : () -> ()
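The "both or neither" rule can be captured in a small model. This is a hypothetical sketch. `to_legacy_attrs`, `ConversionError`, and the dict representation of attributes are illustrative assumptions. The sketch also assumes the compile device type is "TPU", as in the example, because the text does not describe other device types.

```python
class ConversionError(Exception):
    """Raised when an op carries only one of the two unified attributes."""


def to_legacy_attrs(attrs: dict) -> dict:
    """Toy model of the unified-to-legacy conversion (not TensorFlow code)."""
    has_info = "_replication_info" in attrs
    has_type = "_xla_compile_device_type" in attrs
    if has_info != has_type:
        # Exactly one unified attribute is present: the pass fails.
        raise ConversionError("expected both or neither unified attribute")
    out = dict(attrs)
    if has_info:
        # Assumption: the device type is "TPU", as in the example above.
        out["_tpu_replicate"] = out.pop("_replication_info")
        del out["_xla_compile_device_type"]
    return out


legacy = to_legacy_attrs({"_replication_info": "cluster", "_xla_compile_device_type": "TPU", "use_tpu": True})
untouched = to_legacy_attrs({"use_tpu": True})  # neither attribute: no change
```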
-tf-data-optimization
Performs tf.data optimizations
-tf-decompose-reduce-dataset
Decomposes ReduceDataset op into dataset operations.
Decomposes ReduceDataset op into a while loop that iterates the dataset and calls into the reduction function. This decomposition is only done if the ReduceDataset op is marked for compilation with the _xla_compile_device_type attribute.
For example, the ReduceDataset op in the following function:
func.func @single_state_single_dataset_type_no_arguments(
%arg0: tensor<!tf_type.variant>,
%arg1: tensor<i64>
) {
%1 = "tf.ReduceDataset"(%arg0, %arg1) {
Targuments = [],
Tstate = [i64], device = "",
f = @__reduce_func_1, f._tf_data_function = true,
output_shapes = [#tf_type.shape<>],
output_types = [i64], use_inter_op_parallelism = true, _xla_compile_device_type="TPU"} :
(tensor<!tf_type.variant>, tensor<i64>) -> (tensor<i64>)
func.return
}
with the following reduction function:
func.func private @__reduce_func_1(%arg0: tensor<i64> {tf._user_specified_name = "args_0"},
%arg1: tensor<32xf32> {tf._user_specified_name = "args_1"}) -> (tensor<i64>)
attributes {tf._tf_data_function = true, tf.signature.is_stateful} {
%0 = "tf.JustPretend"(%arg1) : (tensor<32xf32>) -> (tensor<i64>)
func.return %0 : tensor<i64>
}
will be transformed into:
func.func @single_state_single_dataset_type_no_arguments(%arg0: tensor<!tf_type.variant>, %arg1: tensor<i64>) {
%0 = "tf.AnonymousIteratorV3"() {output_shapes = [#tf_type.shape<32>], output_types = [f32]} : () -> tensor<!tf_type.resource>
"tf.MakeIterator"(%arg0, %0) : (tensor<!tf_type.variant>, tensor<!tf_type.resource>) -> ()
%cst = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
%1:2 = "tf.WhileRegion"(%cst, %arg1) ({
^bb0(%arg2: tensor<i1>, %arg3: tensor<i64>):
"tf.Yield"(%arg2) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<i1>, %arg3: tensor<i64>):
%2 = "tf.IteratorGetNextAsOptional"(%0) {output_shapes = [#tf_type.shape<32>], output_types = [f32]} : (tensor<!tf_type.resource>) -> tensor<!tf_type.variant>
%3 = "tf.OptionalHasValue"(%2) : (tensor<!tf_type.variant>) -> tensor<i1>
%4 = "tf.IfRegion"(%3) ({
%5 = "tf.OptionalGetValue"(%2) : (tensor<!tf_type.variant>) -> tensor<32xf32>
%6 = func.call @__reduce_func_1(%arg3, %5) {_xla_compile_device_type = "TPU"} : (tensor<i64>, tensor<32xf32>) -> tensor<i64>
"tf.Yield"(%6) : (tensor<i64>) -> ()
}, {
"tf.Yield"(%arg3) : (tensor<i64>) -> ()
}) {_lower_using_switch_merge = true, is_stateless = false} : (tensor<i1>) -> tensor<i64>
"tf.Yield"(%3, %4) : (tensor<i1>, tensor<i64>) -> ()
}) {_lower_using_switch_merge = true, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i1>, tensor<i64>) -> (tensor<i1>, tensor<i64>)
return
}
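In plain Python, the control flow of the decomposed IR corresponds roughly to the loop below. This is an illustrative model only. `decomposed_reduce` and `reduce_fn` are hypothetical names, and a Python iterator with a sentinel stands in for tf.AnonymousIteratorV3, tf.IteratorGetNextAsOptional, and tf.OptionalHasValue.

```python
from typing import Callable, Iterable, TypeVar

S = TypeVar("S")
E = TypeVar("E")

_NO_VALUE = object()  # stands in for an empty optional


def decomposed_reduce(dataset: Iterable[E], init_state: S, reduce_fn: Callable[[S, E], S]) -> S:
    iterator = iter(dataset)                   # tf.AnonymousIteratorV3 + tf.MakeIterator
    keep_going, state = True, init_state       # tf.Const true and the initial state %arg1
    while keep_going:                          # tf.WhileRegion: condition region yields %arg2
        optional = next(iterator, _NO_VALUE)   # tf.IteratorGetNextAsOptional
        has_value = optional is not _NO_VALUE  # tf.OptionalHasValue
        if has_value:                          # tf.IfRegion "then" branch
            state = reduce_fn(state, optional) # tf.OptionalGetValue + call to the reduction function
        # The "else" branch yields the state unchanged.
        keep_going = has_value                 # loop again only while elements remain
    return state


count = decomposed_reduce([[1.0, 2.0], [3.0]], 0, lambda state, elem: state + len(elem))
```

The loop condition is the `has_value` result from the previous iteration. This is why the WhileRegion body yields `%3` (the OptionalHasValue result) as its first value.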
-tf-device-assignment-by-func-attr
Device assignment in TF dialect using the device specified in the function attribute.
-tf-device-cluster-formation
Form clusters from instructions assigned to same device
Clusters operations that have the same device assignment. For each cluster, creates a "tf_device.launch" op with a region containing the ops in that cluster, and replaces those ops with the new launch op.
For example, given the following program:
%2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%3 = "tf.B"(%2) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
%4 = "tf.C"(%2, %3) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
%5 = "tf.D"(%4) : (tensor<?xi32>) -> tensor<?xi32>
After the pass, we will have:
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%1 = "tf_device.launch"() ( {
%3 = "tf.B"(%0) : (tensor<?xi32>) -> tensor<?xi32>
%4 = "tf.C"(%0, %3) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}) {device = "tpu0"} : () -> tensor<?xi32>
%2 = "tf.D"(%1) : (tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
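The grouping in this example can be approximated with a simplified model. In this hypothetical sketch, each op is a `(name, device)` pair, and only adjacent ops with the same non-empty device are grouped. The real pass works on MLIR blocks and must also respect data dependencies when it moves ops into a cluster. This sketch ignores both concerns and never groups non-adjacent ops.

```python
from itertools import groupby


def form_clusters(ops):
    """Group runs of adjacent (name, device) ops that share a non-empty device.

    Returns a list whose items are either a plain op name (no device) or a
    ("launch", device, [op names]) tuple standing in for tf_device.launch.
    The device moves from the individual ops onto the launch, as in the IR above.
    """
    result = []
    for device, run in groupby(ops, key=lambda op: op[1]):
        names = [name for name, _ in run]
        if device:
            result.append(("launch", device, names))
        else:
            result.extend(names)
    return result


program = [("tf.A", ""), ("tf.B", "tpu0"), ("tf.C", "tpu0"), ("tf.D", "")]
clusters = form_clusters(program)
```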
-tf-device-cluster-outlining
Outlines regions of tf_device.cluster operations
This pass outlines the body of a tf_device.cluster into a function and
replaces the tf_device.cluster op with an equivalent tf_device.cluster_func
op. Implicit operands will be captured and materialized as explicit arguments to
the newly created functions and associated tf_device.cluster_func ops.
For example, the following:
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
will be transformed into:
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
%cluster = "tf_device.cluster_func"(%arg0) {func = @_func} : (tensor<i32>) -> tensor<i32>
return %cluster : tensor<i32>
}
func @_func(%arg0: tensor<i32>) -> tensor<i32> {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %identity : tensor<i32>
}
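Finding the implicit operands that must become explicit arguments amounts to collecting the values a region uses but does not define. The following is a minimal sketch under an assumed, hypothetical representation: each op in the region is a `(results, operands)` pair of SSA value names, and the terminator is an op with no results.

```python
def implicit_captures(region_ops):
    """Return operands used in the region but defined outside it, in first-use order."""
    defined = set()
    captures = []
    for results, operands in region_ops:
        for value in operands:
            # A value not produced by an earlier op in the region comes from outside.
            if value not in defined and value not in captures:
                captures.append(value)
        defined.update(results)
    return captures


# The tf_device.cluster body from the example: tf.Identity, then tf_device.return.
cluster_body = [(["%identity"], ["%arg0"]), ([], ["%identity"])]
captures = implicit_captures(cluster_body)
```

Each captured value becomes both an argument of the outlined function and an operand of the tf_device.cluster_func op. In the example, this is why `%arg0` is passed explicitly.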
-tf-device-constant-sinking
Sinks constants implicitly captured in a tf_device.cluster region.
This pass sinks implicitly captured constants (tf.Const ops) that are used
inside a tf_device.cluster region into that region. Performing this before
outlining reduces the number of arguments of the outlined function.
For example, the following:
func @cluster() -> tensor<i32> {
%const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
will be transformed into:
func @cluster() -> tensor<i32> {
%cluster = "tf_device.cluster"() ( {
%const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
-tf-device-convert-launch-func-to-tf-call
Rewrites tf_device::LaunchFuncOp to TF::PartitionedCallOp
This pass converts tf_device::LaunchFuncOp into an equivalent TF::PartitionedCallOp so that it can be exported to TensorFlow GraphDef.
-tf-device-index-selector
Fold tf.DeviceIndex to constant.
-tf-device-launch-outlining
Outlines regions of tf_device.launch operations
This pass outlines the body of a tf_device.launch into a function and
replaces the tf_device.launch op with an equivalent tf_device.launch_func
op. Implicit operands will be captured and materialized as explicit arguments to
the newly created functions and associated tf_device.launch_func ops. The
device attribute from the launch op is transferred to launch_func.
For example, the following:
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
%launch = "tf_device.launch"() ( {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) {device = "some_device"} : () -> (tensor<i32>)
return %launch : tensor<i32>
}
will be transformed into:
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
%launch = "tf_device.launch_func"(%arg0) {device = "some_device", func = @_func} : (tensor<i32>) -> tensor<i32>
return %launch : tensor<i32>
}
func @_func(%arg0: tensor<i32>) -> tensor<i32> {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %identity : tensor<i32>
}
-globally-unique-func-names : If true, the pass adds extra identifiers to make function names globally unique within a process, not just within a module.
-tf-device-mark-input-output-aliases
Marks device cluster input-output pairs that read/write to the same variable as aliases
This pass analyzes the inputs and outputs of a device cluster and marks as
aliases (using the tf.aliasing_output attribute) those input-output pairs that
read from and write to the same resource. This aliasing information can then be
propagated to the XLA compiler for input/output buffer space optimizations.
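To picture the pairing, consider a model in which each cluster input records the resource it was read from and each output records the resource it is written back to. This hypothetical sketch uses an assumed representation and illustrative names. It does not reproduce how the pass discovers these facts from the IR, and it assumes each resource is read at most once. The pass records the output index on the input through tf.aliasing_output. The sketch simply returns that mapping.

```python
from typing import Dict, List, Optional


def mark_aliases(input_reads: List[Optional[str]], output_writes: List[Optional[str]]) -> Dict[int, int]:
    """Map input index -> output index when both refer to the same resource.

    input_reads[i]: resource read to produce input i (None if not a resource read).
    output_writes[j]: resource written with output j (None if not written back).
    """
    read_index = {res: i for i, res in enumerate(input_reads) if res is not None}
    aliases = {}
    for j, res in enumerate(output_writes):
        if res is not None and res in read_index:
            aliases[read_index[res]] = j  # this input buffer can be reused for output j
    return aliases


aliases = mark_aliases(["var_a", None, "var_b"], ["var_b", "var_c", "var_a"])
```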
-tf-drop-while-shape-invariant
Drop shape_invariant attribute from While/WhileRegion ops.
Drops the shape_invariant attribute from tf.While and tf.WhileRegion ops. This
allows the shape inference pass to further refine the operand/result shapes of
these ops. This is only safe to do when compiling to XLA.
-tf-drop-while-shape-invariant-in-device-cluster
Drop shape_invariant attribute from While/WhileRegion ops inside device cluster.
Drops the shape_invariant attribute from tf.While and tf.WhileRegion ops, but
only inside a device cluster. This allows the shape inference pass to further
refine the operand/result shapes of these ops. This is only safe to do when
compiling to XLA.
-tf-einsum
Transform Einsum to other TF Ops for the supported variants
-tf-embedding-pipelining
Rewrite graph for embedding pipelining
For architectures that support accelerated embedding lookups, this pass will rewrite the graph to use pipelining for better device utilization.
-tf-embedding-program-key
Sets the program key for embedding ops.
Passes the program key to embedding ops. If an embedding op has no preceding _TPUCompileMlir op, the pass moves it after a _TPUCompileMlir op. Both the embedding op and the compile op are assumed to be wrapped in separate tf_device.launch ops: the embedding op because it is head outside compiled, and the compile op because TPURewritePass wraps it in a launch so that it executes on the host.
For example, the tf.OpA with the mini_batch_splits attribute will be
moved after _TPUCompileMlir and the first input will use the
_TPUCompileMlir program output:
"tf_device.launch"() ({
%cst_0 = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string>
"tf.OpA"(%cst_0) { mini_batch_splits = ""} : (tensor<1x!tf_type.string>) -> ()
tf_device.return
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
%0:2 = "tf_device.launch"() ({
%compilation_status, %program = "tf._TPUCompileMlir"() { metadata = "...", mlir_module = "..." } : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
tf_device.return %compilation_status, %program : tensor<!tf_type.string>, tensor<3x!tf_type.string>
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
becomes:
%0:2 = "tf_device.launch"() ({
%compilation_status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
tf_device.return %compilation_status, %program : tensor<!tf_type.string>, tensor<3x!tf_type.string>
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
"tf_device.launch"() ({
%cst = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string>
"tf.OpA"(%0#1) {mini_batch_splits = ""} : (tensor<3x!tf_type.string>) -> ()
tf_device.return
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
-tf-embedding-sequencing
Rewrite graph for sequential execution of embeddings
This is a strictly sequential and formally correct fallback option for the embedding pipelining pass intended for debugging during pipelining development.
-tf-executor-break-up-islands
Transform from TF control dialect to TF executor dialect.
-tf-executor-check-control-dependencies
Checks control dependencies
This pass analyzes control dependencies between islands and warns about dependencies that are not explainable by side effects of the involved ops. More precisely, for every minimal unexplainable control dependency path we emit op warnings for all involved ops. The pass does not report intermediate dummy ops for grouping control dependencies (Identity, NoOp), unless they are part of an unexplainable path between other ops. This pass is useful to understand control dependency conservatism for a given MLIR module.
For example, the following function
func.func @path_with_intermediate_ops(
%arg0: tensor<!tf_type.resource<tensor<f32>>>,
%arg1: tensor<!tf_type.resource<tensor<f32>>>,
%arg2: tensor<f32>) -> () {
tf_executor.graph {
%island1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
%island2 = tf_executor.island(%island1) wraps "tf.NoOp"() : () -> ()
%island3 = tf_executor.island(%island2) wraps "tf.NoOp"() : () -> ()
%island4 = tf_executor.island(%island3) wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
tf_executor.fetch
}
func.return
}
produces the following warnings
6:45: warning: unexpected control dependency path: path 0, node 0 (source)
%island1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
^
6:45: note: see current operation: %control = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
7:55: warning: unexpected control dependency path: path 0, node 1 (intermediate)
%island2 = tf_executor.island(%island1) wraps "tf.NoOp"() : () -> ()
^
7:55: note: see current operation: %control_0 = tf_executor.island(%control) wraps "tf.NoOp"() : () -> ()
8:55: warning: unexpected control dependency path: path 0, node 2 (intermediate)
%island3 = tf_executor.island(%island2) wraps "tf.NoOp"() : () -> ()
^
8:55: note: see current operation: %control_1 = tf_executor.island(%control_0) wraps "tf.NoOp"() : () -> ()
9:55: warning: unexpected control dependency path: path 0, node 3 (target)
%island4 = tf_executor.island(%island3) wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
^
9:55: note: see current operation: %control_2 = tf_executor.island(%control_1) wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
because the first and last AssignVariableOps access different resources
and therefore should be independent. Note that the NoOps are considered
as intermediate ops for control dependency grouping.
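The core of this check can be modeled in a few lines. In this hypothetical sketch, the names and data layout are assumptions. Each island maps to the set of resources its op accesses, or to None for a grouping op such as NoOp or Identity. The model reports a control path between two resource-accessing islands that passes only through grouping ops when the two resource sets are disjoint. The real pass relies on TensorFlow's side effect analysis, which is more precise. For example, that analysis distinguishes reads from writes, while this set comparison does not.

```python
from typing import Dict, List, Optional, Set


def unexplained_paths(
    resources: Dict[str, Optional[Set[str]]],
    control_preds: Dict[str, List[str]],
) -> List[List[str]]:
    """Return control paths source -> ... -> target whose endpoints share no resource."""
    paths = []

    def walk_back(node: str, suffix: List[str]) -> None:
        for pred in control_preds.get(node, []):
            path = [pred] + suffix
            if resources[pred] is None:
                walk_back(pred, path)  # grouping op: keep walking through it
            elif not (resources[pred] & resources[path[-1]]):
                paths.append(path)  # endpoints touch different resources

    for node, res in resources.items():
        if res is not None:
            walk_back(node, [node])
    return paths


# The four islands from the example above.
resources = {"island1": {"arg0"}, "island2": None, "island3": None, "island4": {"arg1"}}
control_preds = {"island2": ["island1"], "island3": ["island2"], "island4": ["island3"]}
paths = unexplained_paths(resources, control_preds)
```

The single reported path has a source, two intermediate nodes, and a target. This mirrors the node roles in the warnings above.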
-tf-executor-convert-control-to-data-outputs
Chain control outputs of while loop body
This pass converts the control outputs of a while loop body function to data outputs, so inter-iteration control dependencies become data dependencies. Because data dependencies can express which particular operations in the while loop body depend on which inputs, they capture inter-iteration parallelism in the while loop. Control dependencies, on the other hand, create a barrier at the end of the while loop body, blocking any parallelism across iterations.
For example, the following while loop body has a %barrier at the end.
Although there is no data/control dependency between the tf.AssignVariableOp
for %arg0 and the tf.AssignVariableOp for %arg1 across any iteration, the
while loop body has a control barrier (%barrier) at the end which forces
a dependency: the two assign variable ops must wait for each other to
complete before the next iteration starts. Transforming these control
outputs to data outputs removes the dependency between the two assign
variable ops, allowing them to run in parallel across iterations.
Before:
!tf_res = type tensor<!tf_type.resource<tensor<f32>>>
func @while_body(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor<f32>, %arg3: tensor<f32>) -> (!tf_res, !tf_res, tensor<f32>, tensor<f32>) {
%graph:4 = tf_executor.graph {
%assign_0_control = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (!tf_res, tensor<f32>) -> ()
%assign_1_control = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg3) : (!tf_res, tensor<f32>) -> ()
%add_out, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%mul_out, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%barrier = tf_executor.island(%assign_0_control, %assign_1_control, %add_control, %mul_control) wraps "tf.NoOp"() : () -> ()
tf_executor.fetch %arg0, %arg1, %add_out, %mul_out, %barrier : !tf_res, !tf_res, tensor<f32>, tensor<f32>, !tf_executor.control
}
return %graph#0, %graph#1, %graph#2, %graph#3 : !tf_res, !tf_res, tensor<f32>, tensor<f32>
}
After:
func @while_body(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor<f32>, %arg3: tensor<f32>, %chain_0: tensor<i32>, %chain_1: tensor<i32>) -> (!tf_res, !tf_res, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i32>) {
%graph:6 = tf_executor.graph {
%chain_0_src_out, %chain_0_src = tf_executor.island wraps "tf.Identity"(%chain_0) : (tensor<i32>) -> tensor<i32>
%chain_1_src_out, %chain_1_src = tf_executor.island wraps "tf.Identity"(%chain_1) : (tensor<i32>) -> tensor<i32>
%assign_0_control = tf_executor.island(%chain_0_src) wraps "tf.AssignVariableOp"(%arg0, %arg2) : (!tf_res, tensor<f32>) -> ()
%assign_1_control = tf_executor.island(%chain_1_src) wraps "tf.AssignVariableOp"(%arg1, %arg3) : (!tf_res, tensor<f32>) -> ()
%add_out, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%mul_out, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%chain_0_sink, %chain_0_sink_control = tf_executor.island(%assign_0_control) wraps "tf.Identity"(%chain_0) : (tensor<i32>) -> tensor<i32>
%chain_1_sink, %chain_1_sink_control = tf_executor.island(%assign_1_control) wraps "tf.Identity"(%chain_1) : (tensor<i32>) -> tensor<i32>
tf_executor.fetch %arg0, %arg1, %add_out, %mul_out, %chain_0_sink, %chain_1_sink : !tf_res, !tf_res, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i32>
}
return %graph#0, %graph#1, %graph#2, %graph#3, %graph#4, %graph#5 : !tf_res, !tf_res, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i32>
}
-tf-executor-graph-pruning
Prunes unreachable ops in a tf_executor.graph
This pass removes ops from a tf_executor.graph that are not transitively, via
data or control dependencies, connected to the associated tf_executor.fetch
op. The order of ops will be preserved. Functions named main with no
tf.entry_function attribute will not be pruned, as such graphs/functions may
have been imported from a V1 TensorFlow graph, where feeds/fetches/targets are
not provided at certain stages of IR transformation (e.g. pre-placement).
The ops-to-preserve option specifies ops that should not be pruned,
regardless of their reachability.
For example, the following:
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%graph = tf_executor.graph {
%transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
%unreachable_data:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
%reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
%unreachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
}
return %graph : tensor<i32>
}
will be transformed into:
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%graph = tf_executor.graph {
%transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
%transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
%reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
}
return %graph : tensor<i32>
}
-ops-to-preserve : Comma separated list of ops that should not be pruned regardless of reachability
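The pruning rule is a backward reachability walk from the fetch. Below is a minimal sketch with a hypothetical representation. Islands are named, each wraps one op type, and `deps` lists the islands each one depends on through data or control edges. The sketch treats preserved op types as extra roots, so their own dependencies are also kept. That choice is an assumption. The exemption for functions named main without tf.entry_function is not modeled.

```python
from collections import deque


def prune(op_types, deps, fetched, ops_to_preserve=()):
    """Keep islands reachable (via data or control deps) from the fetch, preserving order.

    op_types: island name -> wrapped op type, in program order.
    deps: island name -> names of islands it depends on (data or control).
    """
    roots = list(fetched) + [n for n, t in op_types.items() if t in ops_to_preserve]
    live, worklist = set(), deque(roots)
    while worklist:
        name = worklist.popleft()
        if name not in live:
            live.add(name)
            worklist.extend(deps.get(name, []))
    return [n for n in op_types if n in live]  # original order is preserved


# The graph from the example above.
op_types = {
    "transitive_reachable_data": "tf.Identity",
    "reachable_data": "tf.Identity",
    "unreachable_data": "tf.Const",
    "transitive_reachable_control": "tf.NoOp",
    "reachable_control": "tf.NoOp",
    "unreachable_control": "tf.NoOp",
}
deps = {"reachable_data": ["transitive_reachable_data"], "reachable_control": ["transitive_reachable_control"]}
kept = prune(op_types, deps, ["reachable_data", "reachable_control"])
kept_with_const = prune(op_types, deps, ["reachable_data", "reachable_control"], {"tf.Const"})
```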
-tf-executor-island-coarsening
Walks tf_executor::GraphOp and merges individual tf_executor::IslandOps.
This pass performs whole graph analysis for a graph encapsulated into tf_executor::GraphOp. The analysis identifies all IslandOps within the graph which could be merged together. The goal is to merge as many islands as possible. Once analysis is completed, the pass merges all IslandOps in a single scan.
For example, given the following program with two separate islands:
func @test(%arg0 : tensor<i1>) -> tensor<f32> {
%0 = tf_executor.graph {
%1:2 = tf_executor.island {
%3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
tf_executor.yield %3 : tensor<i1>
}
%2:2 = tf_executor.island(%1#1) {
%4 = "tf.opB"() : () -> tensor<f32>
tf_executor.yield %4 : tensor<f32>
}
tf_executor.fetch %2#0 : tensor<f32>
}
return %0 : tensor<f32>
}
After running this pass, the two islands are merged:
func @test(%arg0: tensor<i1>) -> tensor<f32> {
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island {
%1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
%2 = "tf.opB"() : () -> tensor<f32>
tf_executor.yield %2 : tensor<f32>
}
tf_executor.fetch %outputs : tensor<f32>
}
return %0 : tensor<f32>
}
-tf-executor-split-into-island-per-op
Split islands that wrap multiple ops into one island per op.
Splits an island with multiple ops into multiple islands (one per op). Does not create any control dependencies between new islands, and does not propagate control dependencies that potentially existed between the old islands into the new islands. Maintains existing data dependencies between ops wrapped by the new islands.
Example: original program:
func.func @dangling_print(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<*xi32>) {
%graph:2 = tf_executor.graph {
%island1:3 = tf_executor.island {
%add1 = "tf.Add"(%arg0, %arg1) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%add2 = "tf.Add"(%add1, %arg1) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%res = "tf.Print"(%add2) { message = "add result" } : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield %add1, %add2 : tensor<*xi32>, tensor<*xi32>
}
tf_executor.fetch %island1#0, %island1#1 : tensor<*xi32>, tensor<*xi32>
}
func.return %graph#0, %graph#1 : tensor<*xi32>, tensor<*xi32>
}
will be converted by this pass into:
func.func @dangling_print(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<*xi32>) {
%0:2 = tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.Add"(%arg0, %arg1) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%outputs_0, %control_1 = tf_executor.island wraps "tf.Add"(%outputs, %arg1) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%outputs_2, %control_3 = tf_executor.island wraps "tf.Print"(%outputs_0) {message = "add result"} : (tensor<*xi32>) -> tensor<*xi32>
tf_executor.fetch %outputs, %outputs_0 : tensor<*xi32>, tensor<*xi32>
}
return %0#0, %0#1 : tensor<*xi32>, tensor<*xi32>
}
-tf-executor-to-functional-conversion
Lifts tf_executor.island inner ops from a tf_executor.graph
This pass converts tf_executor.graphs consisting of only tf_executor.islands and a tf_executor.fetch into a sea of nodes consisting of TensorFlow Dialect ops by lifting such ops out of a tf_executor.graph's tf_executor.islands. If V1 control flow ops are present in a tf_executor.graph, an error will be returned.
For example, the following:
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%graph_results:2 = tf_executor.graph {
%island_0_result, %island_0_control = tf_executor.island {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_executor.yield %identity : tensor<i32>
}
%island_1_result, %island_1_control = tf_executor.island {
%identity_n:2 = "tf.IdentityN"(%arg1, %island_0_result) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
tf_executor.yield %identity_n#0 : tensor<i32>
}
tf_executor.fetch %island_0_result, %island_1_result : tensor<i32>, tensor<i32>
}
return %graph_results#0, %graph_results#1 : tensor<i32>, tensor<i32>
}
will be transformed into:
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%identity_n:2 = "tf.IdentityN"(%arg1, %identity) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
return %identity, %identity_n#0 : tensor<i32>, tensor<i32>
}
-tf-executor-tpu-v1-island-coarsening
Merges TPU cluster IslandOps, intended for V1 compatibility mode
This pass is a variant of ExecutorIslandCoarseningPass that is limited to TPU-annotated operations and intended to preserve backward compatibility with TFv1.
-tf-executor-tpu-v1-island-inlining
Inline calls to the nested TPU module.
This pass inlines the islands calling into the nested module that was
outlined, thus reversing the effect of the
-tf-executor-tpu-v1-island-outlining pass.
For example, the following:
module {
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {f = @_tpu_v1_compat_outlined::@bar} : (tensor<f32>) -> tensor<f32>
tf_executor.fetch %outputs : tensor<f32>
}
return %0 : tensor<f32>
}
module @_tpu_v1_compat_outlined {
func nested @bar(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.opA"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
}
}
will be transformed into:
module {
func @foo(%arg0: tensor<f32>) -> tensor<f32> {
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island {
%1 = "tf.opA"(%arg0) : (tensor<f32>) -> tensor<f32>
tf_executor.yield %1 : tensor<f32>
}
tf_executor.fetch %outputs : tensor<f32>
}
return %0 : tensor<f32>
}
}
-tf-executor-tpu-v1-island-outlining
Outline TPU clusters from island into a nested module, so it can be processed like a V2 module, intended for V1 compatibility mode
Extracts the islands containing a TPU cluster computation into an outlined function in a nested module. This allows the usual bridge to run on this nested module, which now exhibits a more friendly "V2-like" structure. This is only intended for V1 compatibility mode, where the bridge runs without feeds/fetches on session create/extend.
For example, given:
func @test() -> tensor<i32> {
%0 = tf_executor.graph {
%output, %control = tf_executor.island {
...
tf_executor.yield %result : tensor<i32>
}
tf_executor.fetch %output : tensor<i32>
}
return %0 : tensor<i32>
}
This pass will create an additional function containing the code in tf_executor.island:
func nested @_tpu_v1_compat_outlined_func0() -> tensor<i32> {
...
}
and will then replace the island with the wrapped call:
func @test() -> tensor<i32> {
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.PartitionedCall"() {
f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0
} : () -> tensor<i32>
tf_executor.fetch %outputs : tensor<i32>
}
return %0 : tensor<i32>
}
-tf-executor-update-control-dependencies
Computes and applies all necessary control dependencies based on side effect analysis.
This pass is intended to run after the split_into_island_per_op pass. That pass splits multi-op islands into multiple islands that each wrap a single op, without adding any control dependencies between the new islands. This pass is therefore needed to make explicit in the IR the ordering constraints between ops that side effect analysis determines.
Example: original program:
func.func @example(%arg0: tensor<*x!tf_type.resource<tensor<32xf32>>>, %arg1: tensor<32xf32>) -> (tensor<32xf32>) {
%graph = tf_executor.graph {
%read0, %read0_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
%assign0_control = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg1) : (tensor<*x!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
%read1, %read1_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
%print, %print_control = tf_executor.island wraps "tf.Print"(%read1) { message = "read1 value" } : (tensor<32xf32>) -> (tensor<32xf32>)
tf_executor.fetch %read1 : tensor<32xf32>
}
func.return %graph : tensor<32xf32>
}
will be converted by this pass into:
func.func @example(%arg0: tensor<*x!tf_type.resource<tensor<32xf32>>>, %arg1: tensor<32xf32>) -> tensor<32xf32> {
%0 = tf_executor.graph {
%read0, %read0_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
%assign0_control = tf_executor.island(%read0_control) wraps "tf.AssignVariableOp"(%arg0, %arg1) : (tensor<*x!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
%read1, %read1_control = tf_executor.island(%assign0_control) wraps "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
%print, %print_control = tf_executor.island(%read1_control) wraps "tf.Print"(%read1) {message = "read1 value"} : (tensor<32xf32>) -> tensor<32xf32>
tf_executor.fetch %read1, %print_control : tensor<32xf32>, !tf_executor.control
}
return %0 : tensor<32xf32>
}
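The ordering rules can be illustrated with a toy model. This is a minimal sketch under stated assumptions, not the pass's implementation: the hypothetical `Op` record summarizes each single-op island by the resources it reads and writes, an op with unknown side effects (like tf.Print here) is assumed to conflict with every other side-effecting op, and dependencies that are already implied transitively are dropped. Applied to the example above, it yields the same control operands: assign0 waits for read0, read1 waits for assign0, and print waits for read1.

```python
from dataclasses import dataclass


@dataclass
class Op:
    """Hypothetical side-effect summary of one single-op island."""
    name: str
    reads: frozenset = frozenset()
    writes: frozenset = frozenset()
    unknown_effect: bool = False  # effects not tied to a specific resource


def has_effect(op):
    return op.unknown_effect or bool(op.reads or op.writes)


def conflicts(a, b):
    """True if the relative order of ops a and b must be preserved."""
    if not (has_effect(a) and has_effect(b)):
        return False
    if a.unknown_effect or b.unknown_effect:
        return True
    # Write/read, read/write and write/write on a shared resource conflict;
    # two reads of the same resource may be reordered.
    return bool(a.writes & (b.reads | b.writes) or b.writes & a.reads)


def control_deps(ops):
    """Map each op name to the minimal set of earlier ops it must wait for."""
    ancestors = {}  # op name -> every op it transitively depends on
    result = {}
    for i, op in enumerate(ops):
        direct = {prev.name for prev in ops[:i] if conflicts(prev, op)}
        # Keep a dependency only if no other direct dependency implies it.
        result[op.name] = {
            d for d in direct
            if not any(d in ancestors[e] for e in direct if e != d)
        }
        ancestors[op.name] = set(direct).union(*(ancestors[d] for d in direct))
    return result


ops = [
    Op("read0", reads=frozenset({"arg0"})),
    Op("assign0", writes=frozenset({"arg0"})),
    Op("read1", reads=frozenset({"arg0"})),
    Op("print", unknown_effect=True),
]
print(control_deps(ops))
```

The sketch does not model adding %print_control to the fetch, which keeps the side-effecting tf.Print alive.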
### -tf-extract-head-tail-outside-compilation
Extracts head or tail outside compilation to separate host launches before/after device cluster.
This pass extracts a CPU computation cluster with the _xla_outside_compilation
annotation from the head or tail of a device cluster.
For example:
%cluster = "tf_device.cluster"() ( {
%a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
%b = "tf.B"(%a) : (tensor<i32>) -> tensor<i32>
%c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
tf_device.return %c : tensor<i32>
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
return %cluster : tensor<i32>
becomes:
%0 = "tf_device.launch"() ( {
%3 = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_device.return %3 : tensor<i32>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
%1 = "tf_device.cluster"() ( {
%3 = "tf.B"(%0) : (tensor<i32>) -> tensor<i32>
tf_device.return %3 : tensor<i32>
}) {device_assignment = [], num_cores_per_replica = 1 : i64, padding_map = [], step_marker_location = "", topology = ""} : () -> tensor<i32>
%2 = "tf_device.launch"() ( {
%3 = "tf.C"(%1) : (tensor<i32>) -> tensor<i32>
tf_device.return %3 : tensor<i32>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
return %2 : tensor<i32>
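Which ops qualify as head or tail can be sketched as a dependency check. In this hypothetical toy model, each op in the cluster is a `ClusterOp` carrying the names of its operands and a flag for the _xla_outside_compilation attribute. A head op may only consume values from outside the cluster or from other head ops, and a tail op may only feed other tail ops or the cluster's results. The real pass may apply further checks (for example around side effects) that are not modeled here.

```python
from dataclasses import dataclass


@dataclass
class ClusterOp:
    """Hypothetical op inside a tf_device.cluster."""
    name: str
    operands: tuple        # names of the values this op consumes
    outside: bool = False  # True if the op has _xla_outside_compilation


def split_head_tail(ops):
    """Split cluster ops into (head, device, tail) lists of op names."""
    defined = {op.name for op in ops}
    head = []
    for op in ops:
        # Head ops may only consume values defined outside the cluster
        # or by other head ops.
        if op.outside and all(v not in defined or v in head for v in op.operands):
            head.append(op.name)
    users = {op.name: [u.name for u in ops if op.name in u.operands] for op in ops}
    tail = []
    for op in reversed(ops):
        # Tail ops may only feed other tail ops (or the cluster results).
        if (op.outside and op.name not in head
                and all(u in tail for u in users[op.name])):
            tail.insert(0, op.name)
    device = [op.name for op in ops if op.name not in head and op.name not in tail]
    return head, device, tail


cluster = [
    ClusterOp("A", ("arg0",), outside=True),
    ClusterOp("B", ("A",)),
    ClusterOp("C", ("B",), outside=True),
]
print(split_head_tail(cluster))
```

On the example above this gives tf.A as the head launch, tf.B as the remaining device cluster and tf.C as the tail launch.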
### -tf-extract-outside-compilation
Extracts device outside compilation computation to a separate tf_device.parallel_execute region.
This pass extracts a CPU computation cluster with _xla_outside_compilation
annotation, which denotes ops that should be run on CPU/host, from a device cluster.
Each outside compilation cluster is moved to
a tf_device.parallel_execute region. The device cluster is also moved to a
tf_device.parallel_execute region. Communication ops between device and host are
added to pass inputs/outputs to/from the outside compiled region.
For example, the following tf_device.cluster with an op marked for xla_outside_compilation:
func @outside_compilation() -> tensor<f32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
%2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
%3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
tf_device.return %3 : tensor<f32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<f32>
return %0 : tensor<f32>
}
will become a tf_device.parallel_execute op with a CPU/host region and a tf_device.cluster with communication ops to send data to/from device/host:
func @outside_compilation() -> tensor<f32> {
%0 = "tf_device.parallel_execute"() ( {
"tf_device.launch"() ( {
%1 = "tf._XlaCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf_type.string>
%2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf_type.string>) -> tensor<f32>
%3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
"tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf_type.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
tf_device.return
}, {
%1 = "tf_device.cluster"() ( {
%2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
%4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
tf_device.return %4 : tensor<f32>
}) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
tf_device.return %1 : tensor<f32>
}) : () -> tensor<f32>
return %0 : tensor<f32>
}
### -tf-extract-tpu-copy-with-dynamic-shape-op
Extract the TPUCopyWithDynamicShapeOp out of the host launch and place it on device launch
This pass looks for TPUCopyWithDynamicShapeOp ops wrapped in a
tf_device.launch with a host device attribute. It extracts these ops and wraps
them in a tf_device.launch with a TPU device attribute so that they
run on the TPU instead of the CPU, while still being compiled on the host.
### -tf-functional-control-flow-to-cfg
Transform functional control flow Ops to MLIR Control Flow Graph (CFG) form
### -tf-functional-control-flow-to-regions
Transforms functional control flow operations to their region-based counterparts
This pass transforms functional control flow operations in the TensorFlow
dialect to their region-based counterparts, i.e., tf.If is transformed to
tf.IfRegion and tf.While is transformed to tf.WhileRegion.
For example, this functional operation
%0 = "tf.If"(%arg0, %arg1) {
then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
} : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
will be transformed into this region-based operation
%0 = "tf.IfRegion"(%arg0) ( {
%1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}, {
%1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
### -tf-functional-to-executor-conversion
Transform from func op to TF executor dialect.
### -tf-fused-kernel-matcher
Matches computations corresponding to optimized fused kernels
### -tf-gpu-op-fusion
Fusion optimization for GPU targets
This pass performs fusion specific to GPU targets. It is an ad-hoc pass for now, but should be integrated with some notion of "target" in the MLIR pipeline in the future.
### -tf-group-by-dialect
Groups ops into functions that only contain one dialect.
Factors operations into subroutines such that all functions only contain a single dialect. Which of the dialects are allowed in the "top" function is configurable.
For example, the code
x.a()
x.b()
%c = y.c()
x.d(%c)
would be transformed into something like
call @x_1()
%c = call @y_1()
call @x_2(%c)
with @x_1, @x_2 and @y_1 filled in.
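The partitioning step can be sketched with `itertools.groupby`, assuming (for illustration only) that an op is a string whose prefix before the first dot names its dialect. The sketch only forms maximal single-dialect runs and numbers them per dialect as in the example; it does not compute the arguments and results of each outlined function, nor does it honor the configurable set of dialects allowed in the top function.

```python
from itertools import groupby


def group_by_dialect(ops):
    """Split a straight-line op sequence into maximal single-dialect runs.

    Returns (function name, ops) pairs, numbering the functions per dialect
    in the style of the example (@x_1, @y_1, @x_2).
    """
    counters = {}
    groups = []
    for dialect, run in groupby(ops, key=lambda op: op.split(".", 1)[0]):
        counters[dialect] = counters.get(dialect, 0) + 1
        groups.append((f"{dialect}_{counters[dialect]}", list(run)))
    return groups


print(group_by_dialect(["x.a", "x.b", "y.c", "x.d"]))
```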
### -tf-guarantee-all-funcs-one-use
Guarantee all FuncOps have only a single use.
### -tf-hoist-loop-invariant
Hoists loop invariant ops to the outside of the loop
Hoists loop invariant ops out of the loop. The pass is similar to the LoopInvariantCodeMotion pass, but it also hoists ReadVariableOps if the variable is read-only.
For example, the following pseudo MLIR code (types are left out for brevity)
func.func @hoist_loop_invariant(%arg0, %arg1) {
%var = "tf.VarHandleOp"() {container="", shared_name="var_name", device = "/device:CPU:0"}
%results:2 = "tf.WhileRegion"(%arg0, %arg1) ({
^bb0(%arg2, %arg3):
%0 = "tf.OpA"() {is_stateless = true}
"tf.Yield"(%0)
}, {
^bb0(%arg2, %arg3):
%1 = "tf.ReadVariableOp"(%var)
%2 = "tf.OpB"(%1) {is_stateless = true}
%3 = "tf.OpC"(%arg2, %2) {is_stateless = true}
%4 = "tf.OpD"(%arg3, %2) {is_stateless = true}
"tf.Yield"(%3, %4)
}) {is_stateless = true}
return %results#0, %results#1
}
would be transformed to
func.func @hoist_loop_invariant(%arg0, %arg1) {
%var = "tf.VarHandleOp"() {container="", shared_name="var_name", device = "/device:CPU:0"}
%1 = "tf.ReadVariableOp"(%var)
%2 = "tf.OpB"(%1) {is_stateless = true}
%results:2 = "tf.WhileRegion"(%arg0, %arg1) ({
^bb0(%arg2, %arg3):
%0 = "tf.OpA"() {is_stateless = true}
"tf.Yield"(%0)
}, {
^bb0(%arg2, %arg3):
%3 = "tf.OpC"(%arg2, %2) {is_stateless = true}
%4 = "tf.OpD"(%arg3, %2) {is_stateless = true}
"tf.Yield"(%3, %4)
}) {is_stateless = true}
return %results#0, %results#1
}
The tf.ReadVariableOp and tf.OpB can be hoisted to the outside of
the loop.
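The hoisting decision for the body region of this example can be sketched as follows. The `BodyOp` record and the `read_only_vars` input are hypothetical; the real pass determines by its own analysis whether a variable is read-only. An op is treated as hoistable when all of its operands are loop invariant (defined outside the loop, or produced by an op already hoisted) and it is either stateless or a tf.ReadVariableOp of a read-only variable.

```python
from dataclasses import dataclass


@dataclass
class BodyOp:
    """Hypothetical op in the body region of a tf.WhileRegion."""
    result: str
    kind: str
    operands: tuple
    is_stateless: bool = False


def hoistable(body_ops, block_args, read_only_vars):
    """Return the results of body ops that can move out of the loop, in order."""
    body_defined = {op.result for op in body_ops}
    hoisted = []
    for op in body_ops:
        # Every operand must be invariant: not a loop-carried block argument,
        # and either defined outside the body or produced by a hoisted op.
        operands_invariant = all(
            v in hoisted or (v not in block_args and v not in body_defined)
            for v in op.operands
        )
        # Stateless ops may move; a read may move only if the variable is read-only.
        movable = op.is_stateless or (
            op.kind == "tf.ReadVariableOp" and op.operands[0] in read_only_vars
        )
        if operands_invariant and movable:
            hoisted.append(op.result)
    return hoisted


body = [
    BodyOp("%1", "tf.ReadVariableOp", ("%var",)),
    BodyOp("%2", "tf.OpB", ("%1",), is_stateless=True),
    BodyOp("%3", "tf.OpC", ("%arg2", "%2"), is_stateless=True),
    BodyOp("%4", "tf.OpD", ("%arg3", "%2"), is_stateless=True),
]
print(hoistable(body, block_args={"%arg2", "%arg3"}, read_only_vars={"%var"}))
```

If %var were written somewhere, the read would stay in the loop and so would tf.OpB, because its operand would no longer be invariant.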
### -tf-hoist-replicate-invariant-resource-writes
Hoists writes to replicate invariant resource variables.
This pass hoists replicate invariant resource variable writes outside the tf_device.replicate op. These writes may have been inserted by other passes, such as resource op lifting. However, if the resource variable is not replicated, the per-replica writes to it are redundant and can be replaced by a single write of the value from the first replica.
The benefit of this optimization is a reduced memory requirement on the host. For multiple writes (one from each replica) to such a variable, the host would allocate buffer space to receive the device output from all replicas, which is not required; the output of the first replica suffices.
### -tf-init-text-file-to-import
Convert InitializeTableFromTextFileV2 ops to LookupTableImportV2Op to remove the dependency on asset files
-tf-saved-model-dir : Directory containing the model exported as a TensorFlow SavedModel. If your model is not based on the TensorFlow SavedModel, use an empty value.
### -tf-layout-assignment
Layout assignment pass.
-force-data-format : Force data format for all layout sensitive ops.
### -tf-localize-var-handles
Creates VarHandleOps next to the operations that use them.
Creates VarHandleOps right next to the operations that use them, one per operation. This is useful for transformations that only end up with a few small snippets of remaining TF code, and wish for those snippets to be self-contained. For example, this would transform
"tf_saved_model.global_tensor"() { sym_name = "v" ... }
func @f(%arg0 {tf_saved_model.bound_input = @v}) {
%1 = "tf.ReadVariableOp"(%arg0)
...
}
to
func @f(%arg0 {tf_saved_model.bound_input = @v}) {
%0 = "tf.VarHandleOp"(sym_name = "v")
%1 = "tf.ReadVariableOp"(%0)
...
}
Note that this pass might leave behind unused values (e.g. %arg0 in the example above), which can later be pruned using DCE.
### -tf-lower-quantized
Lowers ops that require quantized input or output.
This pass rewrites all ops that have at least one input or output that must be a quantized type to ops whose inputs and outputs allow non-quantized types. Examples of quantized types are TF_Qint8 or TF_Quint8.
An example is TF_DequantizeOp, which converts a quantized type to a float. This op is rewritten to generic ops that perform the scale and shift and can operate on non-quantized types.
Currently, TF_DequantizeOp is the only op with a lowering that falls in this category. When more lowerings are added (e.g. QuantizeV2Op), they should be added to this pass.
### -tf-mark-ops-for-outside-compilation
Marks ops in device cluster for outside compilation if they are unsupported on device.
This pass marks unsupported ops in a device cluster with
_xla_outside_compilation attribute so the operations will run on the host
instead of the device. Unsupported ops are ops that can not be code
generated to run on the device for the cluster including:
This pass is conservative in that it will mark all ops for outside compilation that can not be compiled for the device. Exceptions for this are added for ops that will be rewritten or decomposed before compiling on device.
For example, tf_device.cluster op with an unsupported op, tf.UnsupportedOp:
func @unsupported_op() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.UnsupportedOp"() : () -> tensor<i32>
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
tf_device.return %2 : tensor<i32>
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
return %0 : tensor<i32>
}
will mark tf.UnsupportedOp with _xla_outside_compilation attribute:
func @unsupported_op() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor<i32>
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
tf_device.return %2 : tensor<i32>
}) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<i32>
return %0 : tensor<i32>
}
### -tf-materialize-passthrough-op
Materialize the MlirPassthroughOp by replacing it with the MLIR module attached as an attribute
A pass that replaces MlirPassthrough ops with the code they have in
their mlir_module string attribute.
### -tf-merge-control-flow
Merges IfRegion ops together with a common predicate.
This pass merges IfRegion ops together if they have the same predicate and it is safe to do so (there are no intermediate dependencies, they are in the same block, etc).
For example:
"tf.IfRegion"(%0) ( {
%2 = "tf.A"() : () -> (tensor<f32>)
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
}) { is_stateless = true } : (tensor<i1>) -> ()
"tf.IfRegion"(%0) ( {
%2 = "tf.B"() : () -> (tensor<f32>)
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
}) { is_stateless = true } : (tensor<i1>) -> ()
Would be transformed to:
"tf.IfRegion"(%0) ( {
%2 = "tf.A"() : () -> (tensor<f32>)
%3 = "tf.B"() : () -> (tensor<f32>)
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
}) { is_stateless = true } : (tensor<i1>) -> ()
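The merge itself can be sketched on a toy representation. This minimal sketch assumes a hypothetical `IfRegion` record holding op names for each branch, and it only merges IfRegion ops that are directly adjacent in the block, a conservative stand-in for the pass's checks for intermediate dependencies. It also ignores region results, which the example does not use.

```python
from dataclasses import dataclass, field


@dataclass
class IfRegion:
    """Hypothetical tf.IfRegion: a predicate plus op names in each branch."""
    predicate: str
    then_ops: list = field(default_factory=list)
    else_ops: list = field(default_factory=list)


def merge_adjacent_ifs(block):
    """Merge directly adjacent IfRegion ops that share a predicate.

    `block` mixes IfRegion objects and plain op names (strings). A new list is
    returned; the input block and its regions are left unchanged.
    """
    merged = []
    for item in block:
        prev = merged[-1] if merged else None
        if (isinstance(item, IfRegion) and isinstance(prev, IfRegion)
                and prev.predicate == item.predicate):
            merged[-1] = IfRegion(prev.predicate,
                                  prev.then_ops + item.then_ops,
                                  prev.else_ops + item.else_ops)
        else:
            merged.append(item)
    return merged


block = [IfRegion("%0", ["tf.A"]), IfRegion("%0", ["tf.B"])]
print(merge_adjacent_ifs(block))
```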
### -tf-move-transposes
Move transposes pass.
-fold-transpose-in-ops : Whether to fold transposes in ops which can support folding.
-direction : Move transposes to the beginning or the end of the block where they are defined.
### -tf-name-anonymous-iterators
Converts anonymous iterators to named iterators
This converts AnonymousIterator ops to Iterator, thus giving them a name. For example, this will convert %0 = "tf.AnonymousIteratorV3"() {...} to %0 = "tf.Iterator"() {shared_name = "_iterator1", ...}
### -tf-optimize
Optimize TensorFlow module
### -tf-order-by-dialect
Reorders ops so ops of the same dialect are next to each other.
Performs a reordering of ops so that (a) ops of the same dialect are next to each other and (b) order within a dialect is preserved. For example, this would transform
%a = "x.f"()
%b = "y.f"(%a)
%c = "x.f"(%a)
to
%a = "x.f"()
%c = "x.f"(%a)
%b = "y.f"(%a)
so that the two "x" dialect instructions are next to each other.
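One way to compute such an order is a greedy topological sort, sketched below under the assumption that each op is a hypothetical `(result, dialect, operands)` tuple; this is not necessarily the algorithm the pass uses. At each step the candidates are the first not-yet-placed op of each dialect, which preserves order within a dialect, and a candidate whose operands are all placed is chosen, preferring the dialect of the previously placed op.

```python
def order_by_dialect(ops):
    """Reorder (result, dialect, operands) ops so same-dialect ops are adjacent."""
    queues = {}
    for op in ops:
        queues.setdefault(op[1], []).append(op)
    position = {op[0]: i for i, op in enumerate(ops)}
    emitted, order, current = set(), [], None
    while len(order) < len(ops):
        # Candidates: the first unplaced op of each dialect whose operands are
        # placed already or defined outside this block.
        ready = [
            q[0] for q in queues.values()
            if q and all(v in emitted or v not in position for v in q[0][2])
        ]
        same = [op for op in ready if op[1] == current]
        pick = same[0] if same else min(ready, key=lambda op: position[op[0]])
        queues[pick[1]].pop(0)
        emitted.add(pick[0])
        order.append(pick[0])
        current = pick[1]
    return order


ops = [("%a", "x", ()), ("%b", "y", ("%a",)), ("%c", "x", ("%a",))]
print(order_by_dialect(ops))
```

For ops already in a valid SSA order, the earliest unplaced op is always a ready candidate, so the loop cannot stall.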
### -tf-outside-compiled-to-host-launch
Wraps each op with the _xla_outside_compilation attribute in a separate tf_device.launch on replicated host device.
This pass wraps ops with the same _xla_outside_compilation
attribute value in a tf_device.launch op with host device assignment.
A simple example:
"tf_device.cluster"() ( {
"tf.A"()
"tf.B"() {_xla_outside_compilation = "cluster1"}
"tf.C"()
tf_device.return
}) {num_cores_per_replica = 1, topology = "", device_assignment = []}
Would become the following ops (unimportant attributes and types are omitted):
"tf_device.cluster"() ( {
"tf.A"()
"tf_device.launch"() ( {
"tf.B"() {_xla_outside_compilation = "cluster1"}
tf_device.return
}) {device = "TPU_REPLICATED_HOST_0"} : () -> ()
"tf.C"()
tf_device.return
}) {num_cores_per_replica = 1, topology = "", device_assignment = []}
### -tf-parallel-execute-to-islands
Lowers device parallel_execute to executor islands
-legacy-graph-export : Determines whether or not this pass should execute logic that is reserved for the legacy graph export pipeline to maintain expected invariants. In the case of this pass, that means manually propagating controls to lifted parallel execute regions to the graph fetch to ensure the ops execute.
### -tf-promote-resources-to-args
Promote resource reads/writes to function inputs/outputs.
This pass promotes resource accesses in function(s) (by default, the main) to input arguments and outputs of the function(s).
Two types of resources are supported: (1) A function argument of TF::ResourceType type (this pass). (2) A VarHandleOp in the function (tf-promote-var-handles-to-args).
After the pass:
- The function will have an input argument for each resource that is already provided as an input argument or is read. The type of the input argument will become the shape of the value represented by the resource.
- The function will have an output for each resource that is written. The type of the output will become the shape of the resource.
The information of variable identification and input-output aliasing is recorded as named attributes of the input argument or output:
- 'tf.resource_name' matches 'shared_name' of VarHandleOp, which represents the identifier of the corresponding resource. This attribute is added to an input argument if the initial value of the resource is read, or to the output if the initial value is not read.
- 'tf.aliasing_output' is the index of the function output that is an alias of the input argument. This attribute is added only to the input argument when the initial value of the corresponding resource is read, and the resource is written later.
Assumptions of this pass:
- Compound resource operations have already been decomposed.
- Dead functions have already been removed, as resource arguments in dead functions can cause the pass to fail.
-functions : Comma separated list of functions whose resources read/writes should be promoted to function inputs/outputs.
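The attribute rules above can be summarized in a short sketch. It assumes the resources come from VarHandleOps and are described by hypothetical `(shared_name, initial_value_read, written)` triples, and that the promoted outputs are appended after `num_existing_outputs` existing results; both are illustrative assumptions rather than the pass's actual data structures.

```python
def promote_signature(resources, num_existing_outputs=0):
    """Attributes for promoted inputs/outputs (toy model).

    resources: (shared_name, initial_value_read, written) triples, in order.
    Returns one attribute dict per new input argument and per new output.
    """
    written = [name for name, _, was_written in resources if was_written]
    read = {name for name, was_read, _ in resources if was_read}
    inputs = []
    for name, was_read, was_written in resources:
        if was_read:
            attrs = {"tf.resource_name": name}
            if was_written:
                # The written value is returned as an output aliasing this input.
                attrs["tf.aliasing_output"] = num_existing_outputs + written.index(name)
            inputs.append(attrs)
    # The resource name goes on the output only if no input carries it.
    outputs = [{} if name in read else {"tf.resource_name": name} for name in written]
    return inputs, outputs


print(promote_signature([("v", True, True), ("w", True, False), ("u", False, True)]))
```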
### -tf-promote-var-handles-to-args
Promote tf.VarHandleOps to function arguments.
See the joint description under -tf-promote-resources-to-args.
### -tf-readonly-references-to-resources
Convert readonly reference variables to resource variables.
### -tf-region-control-flow-to-functional
Transforms region-based control flow operations to their functional counterparts
This pass transforms region-based control flow operations in the TensorFlow
dialect to their functional counterparts, i.e., tf.IfRegion is transformed to
tf.If and tf.WhileRegion is transformed to tf.While.
For example, this region-based operation
%0 = "tf.IfRegion"(%arg0) ( {
%1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}, {
%1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
will be transformed into this functional operation
%0 = "tf.If"(%arg0, %arg1) {
then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
} : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
### -tf-remove-unused-arguments
Removes unused args from private functions & their callers.
Removes arguments from functions that aren't used in the function body, outside of returns. Also adjusts the callers of said functions.
For example, the code
func.func @f(%arg0, %arg1) {
SomeOpThatUsesArg0(%arg0)
return %arg0
}
...
call @f(x, y)
would be transformed into
func.func @f(%arg0) {
SomeOpThatUsesArg0(%arg0)
return %arg0
}
...
call @f(x)
Note that, in the above example, both args would be removed if the SomeOpThatUsesArg0(%arg0) line were absent.
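A toy model of the argument-removal rule follows. All names are hypothetical; `body_uses` holds the values used by ops other than the return. The `forwarded` map reflects our reading of the note above: when an argument is only returned, the caller can use the value it passed in place of that result. This sketch does not rewrite the callee's return op.

```python
def remove_unused_args(params, body_uses, returns, call_operands):
    """Drop arguments that are never used outside of the return op.

    params: the callee's argument names; body_uses: names used by non-return
    ops; returns: names in the return op; call_operands: values the caller
    passes. Returns the new argument list, the new call operands, and a map
    from result index to the caller value that can replace that result.
    """
    keep = [i for i, p in enumerate(params) if p in body_uses]
    new_params = [params[i] for i in keep]
    new_call_operands = [call_operands[i] for i in keep]
    # A result that merely returns a removed argument can be replaced, at the
    # call site, by the value the caller passed for that argument.
    forwarded = {
        idx: call_operands[params.index(value)]
        for idx, value in enumerate(returns)
        if value in params and value not in body_uses
    }
    return new_params, new_call_operands, forwarded


print(remove_unused_args(["%arg0", "%arg1"], {"%arg0"}, ["%arg0"], ["x", "y"]))
```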
### -tf-remove-unused-while-results
Removes unused results from tf.WhileRegion ops
Removes unused results from tf.WhileRegion ops along with the defining
ops in the body, if it is safe to do so.
Currently, the pass detects results with the following properties:
tf.WhileRegion optf.Identity)
For example, the following pseudo MLIR code (types are left out for brevity)
func.func @remove_first_result(%arg0, %arg1) {
%0:2 = "tf.WhileRegion"(%arg0, %arg1) ({
^bb0(%arg2, %arg3):
%1 = "tf.OpA"() {is_stateless = true}
"tf.Yield"(%1)
}, {
^bb0(%arg2, %arg3):
%1 = "tf.OpB"(%arg2) {is_stateless = true}
%2 = "tf.OpC"(%arg3) {is_stateless = true}
"tf.Yield"(%1, %2)
}) {is_stateless = true}
return %0#1
}
would be transformed to
func.func @remove_first_result(%arg0, %arg1) {
%0 = "tf.WhileRegion"(%arg1) ({
^bb0(%arg3):
%1 = "tf.OpA"() {is_stateless = true}
"tf.Yield"(%1)
}, {
^bb0(%arg3):
%1 = "tf.OpC"(%arg3) {is_stateless = true}
"tf.Yield"(%1)
}) {is_stateless = true}
return %0
}
(the first result can be removed along with its defining op tf.OpB).
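The check for a removable result can be sketched as follows. The conditions are this sketch's own conservative approximation read off the example, not the pass's exact list of properties. A result is removable if it is unused after the loop and its block argument is unused in the condition region. In addition, in the body, the yielded value must come from a stateless op (hypothetical `WhileBodyOp`) whose value feeds only that yield and whose block argument feeds only that op.

```python
from dataclasses import dataclass


@dataclass
class WhileBodyOp:
    """Hypothetical op in the body region of a tf.WhileRegion."""
    result: str
    operands: tuple
    is_stateless: bool = True


def removable_results(body_ops, block_args, yields, cond_uses, used_outside):
    """Indices of tf.WhileRegion results that can be dropped (toy model).

    block_args[i] and yields[i] belong to result i; cond_uses holds the indices
    of block arguments used in the condition; used_outside holds the indices
    of results used after the loop.
    """
    removable = []
    for i, value in enumerate(yields):
        defining = next((op for op in body_ops if op.result == value), None)
        if i in used_outside or i in cond_uses or defining is None:
            continue
        if not defining.is_stateless:
            continue
        others = [op for op in body_ops if op is not defining]
        # The yielded value must feed only the yield at position i ...
        value_only_yielded = (yields.count(value) == 1
                              and all(value not in op.operands for op in others))
        # ... and the loop-carried argument must feed only the defining op.
        arg_only_used_by_def = (block_args[i] not in yields
                                and all(block_args[i] not in op.operands for op in others))
        if value_only_yielded and arg_only_used_by_def:
            removable.append(i)
    return removable


body = [WhileBodyOp("%1", ("%arg2",)), WhileBodyOp("%2", ("%arg3",))]
print(removable_results(body, ["%arg2", "%arg3"], ["%1", "%2"], set(), {1}))
```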
### -tf-replica-id-to-device-ordinal
Set device ordinal with replica id
This pass sets the device ordinal attribute of the ops using the replica id attribute. It is run immediately after the replica_to_island pass, which sets the replica id attribute of these ops. Note that in the single-chip use case, the pass checks whether there is only one such op and sets its device ordinal attribute to zero.
### -tf-replicate-invariant-op-hoisting
Hoists replicate invariant operations out of replicate
This pass looks for replicate invariant ops in a tf_device.replicate op
region and hoists them out. It also makes tf.Shape ops replicate invariant
if possible. This currently updates or replaces tf.Shape ops of replicated
arguments, either tensors or resources.
The primary benefit of the pass is to hoist num_replicas _TPUCompiles
into a single _TPUCompile.
This pass assumes that when a tf.Shape op takes its input directly from
a replicate parameter, the shape is the same across replicas.
For example, the following
tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
%2 = "tf.Shape"(%ri) : (tensor<*xi32>) -> tensor<?xi32>
tf_device.return
}
gets converted to
tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
%2 = "tf.Shape"(%0) : (tensor<*xi32>) -> tensor<?xi32>
tf_device.return
}
and for resource variables the following
tf_device.replicate([%0, %1] as %ri: tensor<*x!tf_type.resource>) {n = 2 : i32} {
%2 = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf_type.resource>) -> tensor<*xi32>
%3 = "tf.Shape"(%2) : (tensor<*xi32>) -> tensor<?xi32>
tf_device.return
}
gets converted to
tf_device.replicate([%0, %1] as %ri: tensor<*x!tf_type.resource>) {n = 2 : i32} {
%2 = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf_type.resource>) -> tensor<*xi32>
%3 = "tf.VariableShape"(%0) : (tensor<*x!tf_type.resource>) -> tensor<?xi32>
tf_device.return
}
### -tf-replicate-tensor-list-init-ops
Replicate TensorList init ops for correct shape assignments in shape inference
If the same TensorList is passed to a while op as multiple arguments, or is used in multiple places with different TensorListSetItem ops assigning its elements, shape inference cannot identify the shapes of these arguments, so the shape of the input TensorList cannot be determined. All of these uses are supposed to be independent and unrelated to the original creation of the TensorList.
This pass creates a separate TensorList instance for each argument of the while op and for each use, so there is no conflict when resolving the shapes of these inputs.
### -tf-replicate-to-island
Lowers device replicate to executor islands
-legacy-graph-export : Determines whether or not this pass should execute logic that is reserved for the legacy graph export pipeline to maintain expected invariants. In the case of this pass, that means manually propagating controls to lifted parallel execute regions to the graph fetch to ensure the ops execute, as well as determining whether or not the islands created by this pass should be split after the replicated ops have been lifted.
### -tf-resource-device-inference
Propagates the device attribute on resources from callers to callees.
A pass that propagates device assignment of resources on a module. It performs in-function propagation, as well as cross-function propagation from callers to callees.
This pass changes the module by adding "tf.device" attribute to function arguments and adding "device" attribute to TF ops.
For example, given the function
!tf_res = type tensor<*x!tf_type.resource<tensor<32xf32>>>
func @test(%arg0: !tf_res {tf.device = "/TPU:0"}) {
tf_executor.graph {
%control = tf_executor.island {
%id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res
tf_executor.yield
}
tf_executor.fetch %control : !tf_executor.control
}
return
}
Observe how the op inside the island obtains a /TPU:0 device assignment:
!tf_res = type tensor<*x!tf_type.resource<tensor<32xf32>>>
func @test(%arg0: !tf_res {tf.device = "/TPU:0"}) {
tf_executor.graph {
%control = tf_executor.island {
%0 = "tf.Identity"(%arg0) {device = "/TPU:0"} : (!tf_res) -> !tf_res
tf_executor.yield
}
tf_executor.fetch %control : !tf_executor.control
}
return
}
### -tf-rewrite-tpu-embedding-ops
Rewrites TPU embedding send/recv ops by adding TPU embedding deduplication data
### -tf-shape-inference
Shape inference on TF dialect and ops implementing InferTypeOpInterface
Fixed point shape refinement pass that utilizes the shape functions registered on ops using the InferTypeOpInterface as well as by bridging to the TensorFlow op registry's shape functions. This is an interprocedural pass that propagates information across function calls/control flow operations where possible (the GuaranteeAllFuncsOneUsePass is often run before this pass to enable more propagation opportunities). It refines both the outermost element type of tensors as well as the nested component type (e.g., for tensor lists).
During shape refinement this pass may insert additional cast operations as well as fold some constant shape computations to enable more exact shape inference, so it does mutate the graph to some extent. Constant folding needed to produce more exact shapes is also performed, but those folded values are only kept in the pass's context; the corresponding ops are not folded and the IR is not mutated.
-max-iterations : Maximum shape inference iterations
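The fixed-point iteration can be illustrated with a toy model. Shapes are tuples with `None` for an unknown dimension (and `None` for an unranked value), and every op is assumed to be elementwise without broadcasting, so its result shape is the merge of its operand shapes. This sketches the iteration scheme only; the real pass relies on InferTypeOpInterface and the TensorFlow op registry's shape functions. The `max_iterations` parameter plays the role of the -max-iterations option.

```python
def merge_shapes(a, b):
    """Combine two partial shapes; None means unranked, a None dim is unknown."""
    if a is None:
        return b
    if b is None:
        return a
    if len(a) != len(b):
        raise ValueError(f"rank mismatch: {a} vs {b}")
    dims = []
    for x, y in zip(a, b):
        if x is not None and y is not None and x != y:
            raise ValueError(f"incompatible dims: {a} vs {b}")
        dims.append(x if x is not None else y)
    return tuple(dims)


def infer_shapes(shapes, elementwise_ops, max_iterations=10):
    """Refine `shapes` in place until nothing changes or the limit is hit.

    elementwise_ops: (result, (operand, ...)) pairs. Returns the number of
    iterations that refined at least one shape.
    """
    for iteration in range(max_iterations):
        changed = False
        for result, operands in elementwise_ops:
            merged = shapes.get(result)
            for v in operands:
                merged = merge_shapes(merged, shapes.get(v))
            if merged != shapes.get(result):
                shapes[result] = merged
                changed = True
        if not changed:
            return iteration
    return max_iterations


shapes = {"a": (2, None), "b": (None, 3)}
ops = [("d", ("c",)), ("c", ("a", "b"))]  # listed out of dependency order
print(infer_shapes(shapes, ops), shapes)
```

Because the ops are visited out of order, the refined shape of c only reaches d on the second iteration, which is why iterating to a fixed point matters.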
### -tf-simple-device-assignment
Simple device assignment in TF dialect.
Assigns the default device to all ops that have an empty (or nonexistent) device attribute.
For example, if we have the code
%0 = "tf.Const"() {value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%1 = "tf.Const"() {device = "", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%2 = "tf.Const"() {device = "baz", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
then running this pass with 'default-device=foobar', we get:
%0 = "tf.Const"() {device = "foobar", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%1 = "tf.Const"() {device = "foobar", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%2 = "tf.Const"() {device = "baz", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-default-device : The default device to assign.
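The assignment rule amounts to the following minimal sketch, in which each op is represented (for illustration only) by a dict of its attributes:

```python
def assign_default_device(ops, default_device):
    """Return copies of op attribute dicts with empty/missing devices filled in."""
    assigned = []
    for attrs in ops:
        attrs = dict(attrs)  # leave the caller's dicts untouched
        if not attrs.get("device"):  # missing or empty device attribute
            attrs["device"] = default_device
        assigned.append(attrs)
    return assigned


ops = [{}, {"device": ""}, {"device": "baz"}]
print(assign_default_device(ops, "foobar"))
```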
### -tf-stack-ops-decomposition
Decompose stack operations into local variable operations. Needs static shapes.
A pass that converts stack operations to tensor operations and read/assign ops on local variables. A later resource lifting pass can further remove the local variables.
This pass requires that the full shape of the stack can be inferred: 1) the maximum size needs to be a constant and 2) a push op can be found with a known shape, and all push ops need to have the same shape.
A stack creation op "tf.StackV2" will be turned into two zero-initialized variables, one for the buffer and one for the current size. Each push will be turned into
%old_val = "tf.ReadVariableOp"(%buffer)
%old_size = "tf.ReadVariableOp"(%size)
%offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
%new_val = "tf.XlaDynamicUpdateSlice"(%old_val, %push_val, %offsets)
"tf.AssignVariableOp"(%buffer, %new_val)
%new_size = "tf.AddV2"(%old_size, %const1)
"tf.AssignVariableOp"(%size, %new_size)
and each pop will be turned into
%old_val = "tf.ReadVariableOp"(%buffer)
%old_size = "tf.ReadVariableOp"(%size)
%new_size = "tf.Sub"(%old_size, %const1)
%offsets = "tf.ConcatV2"(%new_size, %other_dims_0s, %const0)
%slice = "tf.Slice"(%old_val, %offsets, %slice_size_const)
%pop_result = "tf.Reshape"(%slice, %elem_size_const)
"tf.AssignVariableOp"(%size, %new_size)
The pass also works across control flow and functional calls.
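The buffer/size bookkeeping performed by the rewritten push and pop ops can be modeled in plain Python. In this sketch the hypothetical `DecomposedStack` class holds the two variables as ordinary fields and stores scalars instead of tensor slices; the bounds checks belong to the sketch and are not part of the generated IR.

```python
class DecomposedStack:
    """Toy model of a stack lowered to a fixed-size buffer plus a size variable."""

    def __init__(self, max_size, zero=0.0):
        # tf.StackV2 -> two zero-initialized variables.
        self.buffer = [zero] * max_size
        self.size = 0

    def push(self, value):
        if self.size == len(self.buffer):
            raise IndexError("stack is full")
        new_buffer = list(self.buffer)  # read the buffer, then update one slot
        new_buffer[self.size] = value   # write at index `size`
        self.buffer = new_buffer        # assign the buffer variable
        self.size += 1                  # assign the incremented size

    def pop(self):
        if self.size == 0:
            raise IndexError("stack is empty")
        self.size -= 1                  # decremented size
        return self.buffer[self.size]   # the top element sits at the new size


s = DecomposedStack(3)
s.push(1.0)
s.push(2.0)
print(s.pop(), s.pop())
```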
### -tf-strip-noinline-attribute
Strip the tf.noinline attribute from top-level functions.
### -tf-strip-tf-attributes
Removes TF specific attributes
Removes attributes that are TF specific (start with "tf.") or that have a value from the TF dialect. Useful after legalizing TF graphs to other dialects, to remove any TF remnants.
### -tf-tensor-array-ops-decomposition
Decompose tensor array operations into local variable operations.
A pass that converts tensor array operations to tensor operations and read/assign ops on local variables. A later resource lifting pass can further remove the local variables.
This pass requires that the full shape of the tensor array can be inferred:
### -tf-tensor-device-copy
Fold the tf.Identity op and the tf.IdentityN op if the op has the same device as its operand
### -tf-tensor-list-ops-decomposition
Decomposes TensorList operations into generic operations on tensors.
This pass rewrites TensorList operations into generic and non-mutating operations on tensors. This results in operations that can be legalized to XLA.
The list is converted to a single large tensor that includes all list elements, with a new first dimension for the list index. List update operations are converted to operations that create a new tensor representing the list.
In the current implementation, the resulting operations are statically shaped,
which means it must be possible to infer a bound on the full shape of the
TensorList. That is, the element_shape and num_elements arguments to a
tensor list creation op are constant.
A tensor list creation op tf.EmptyTensorList/tf.TensorListReserve will be
turned into a zero-initialized buffer, and the size is initialized to 0
for tf.EmptyTensorList or the specified size for tf.TensorListReserve.
Each push will be turned into tf.XlaDynamicUpdateSlice with the incremented
size, and each pop will be turned into a tf.Slice and a copy of the buffer
with decremented size. Each tf.TensorListSetItem will be turned into a
tf.XlaDynamicUpdateSlice with unchanged size, and each tf.TensorListGetItem
will be rewritten to a tf.Slice.
The pass also works across control flow and functional calls.
For example, the TensorList ops in the following function:
func @main(%arg0: tensor<8x4xf32>) {
%elem_shape = "tf.Const"() {value = dense<[8, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<2xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<8x4xf32>>>
%push = "tf.TensorListPushBack"(%tl, %arg0) : (tensor<!tf_type.variant<tensor<8x4xf32>>>, tensor<8x4xf32>) -> tensor<!tf_type.variant<tensor<8x4xf32>>>
return
}
will be transformed to:
func @main(%arg0: tensor<8x4xf32>) {
// EmptyTensorList lowering
%emptyi = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%emptyf = "tf.Cast"(%emptyi) : (tensor<i32>) -> tensor<f32>
%size_shape = "tf.Const"() {value = dense<[10, 8, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
%tl = "tf.BroadcastTo"(%emptyf, %size_shape) : (tensor<f32>, tensor<3xi32>) -> tensor<10x8x4xf32>
// TensorListPushBack lowering
%index_in_list = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%arg0_shape = "tf.Const"() {value = dense<[1, 8, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
%arg0_reshaped = "tf.Reshape"(%arg0, %arg0_shape) : (tensor<8x4xf32>, tensor<3xi32>) -> tensor<1x8x4xf32>
%zeroi2 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
%axis = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%start_indices = "tf.ConcatV2"(%index_in_list, %zeroi2, %axis) : (tensor<1xi32>, tensor<2xi32>, tensor<i32>) -> tensor<3xi32>
%push = "tf.XlaDynamicUpdateSlice"(%tl, %arg0_reshaped, %start_indices) : (tensor<10x8x4xf32>, tensor<1x8x4xf32>, tensor<3xi32>) -> tensor<10x8x4xf32>
%one = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%next_index_in_list = "tf.AddV2"(%index_in_list, %one) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
return
}
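The buffer-and-size bookkeeping described above can be modeled in plain Python. This is a minimal sketch, not the pass implementation. `LoweredList` and the helper functions are hypothetical names, elements are scalars rather than 8x4 tensors, and tuple updates stand in for tf.XlaDynamicUpdateSlice and tf.Slice.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class LoweredList:
    """Hypothetical model of a lowered TensorList: buffer plus size."""
    buffer: Tuple[float, ...]  # stands in for the zero-initialized buffer tensor
    size: int                  # stands in for the separate size value


def empty_tensor_list(max_size: int) -> LoweredList:
    # tf.EmptyTensorList: zero buffer of max_size elements, size starts at 0.
    return LoweredList((0.0,) * max_size, 0)


def tensor_list_reserve(num_elements: int) -> LoweredList:
    # tf.TensorListReserve: zero buffer, size starts at the reserved count.
    return LoweredList((0.0,) * num_elements, num_elements)


def _update(buffer: Tuple[float, ...], index: int, value: float) -> Tuple[float, ...]:
    # Models tf.XlaDynamicUpdateSlice writing a single element at `index`.
    return buffer[:index] + (value,) + buffer[index + 1:]


def push_back(tl: LoweredList, value: float) -> LoweredList:
    # Write at position `size`, then increment the size.
    return LoweredList(_update(tl.buffer, tl.size, value), tl.size + 1)


def pop_back(tl: LoweredList) -> Tuple[LoweredList, float]:
    # Slice out the last element; the buffer is kept as-is and only the
    # size is decremented.
    new_size = tl.size - 1
    return LoweredList(tl.buffer, new_size), tl.buffer[new_size]


def set_item(tl: LoweredList, index: int, value: float) -> LoweredList:
    # tf.TensorListSetItem: dynamic update slice with the size unchanged.
    return LoweredList(_update(tl.buffer, index, value), tl.size)


def get_item(tl: LoweredList, index: int) -> float:
    # tf.TensorListGetItem: a plain slice of the buffer.
    return tl.buffer[index]
```

For example, `push_back(empty_tensor_list(10), 3.0)` mirrors the function above. The buffer keeps its 10 slots and the size becomes 1.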
-tf-tpu-annotate-dynamic-shape-inputs
Annotate the inputs returned by TPUCopyWithDynamicShapeOp with dynamic shape
This pass looks for uses of the results of TPUCopyWithDynamicShapeOp and marks those inputs as dynamically shaped. This ensures that the generated HLO program correctly reflects the dynamic shape.
-tf-tpu-cleanup-cluster-attributes
Eliminate _replication_info and other attributes from ops in a cluster
This pass eliminates the _replication_info and device attributes on operations
that are contained in a tf_device.cluster op.
-tf-tpu-cluster-formation
Forms clusters from operations assigned to the same TPU computation
TPU computations from the frontend are composed of a tf.TPUReplicateMetadata
op, a subgraph of ops (TensorFlow Dialect) each with a matching
_replication_info attribute relative to the associated
tf.TPUReplicateMetadata op, and optionally tf.TPUReplicatedInput and
tf.TPUReplicatedOutput ops feeding in inputs and outputs to and from a
replicated TPU computation. The number of times a TPU computation is
replicated is defined in the tf.TPUReplicateMetadata op (num_replicas
attribute), and the operand count of tf.TPUReplicatedInput and the result
count of tf.TPUReplicatedOutput must each match it, excluding packed tensors.
It is also assumed that no op outside a TPU computation is both an input to
and an output of that computation. Furthermore, it is assumed that every op
has either both or neither of the _replication_info and
_xla_compile_device_type attributes defined.
This pass takes the TPU computation subgraph, moves it into a
tf_device.cluster, and copies over attributes from the associated
tf.TPUReplicateMetadata op to the newly created tf_device.cluster. If the
computation is replicated (num_replicas > 1), the num_replicas attribute is
not copied over but instead the tf_device.cluster is further wrapped with a
tf_device.replicate, and associated tf.TPUReplicatedInput and
tf.TPUReplicatedOutput ops are replaced by the tf_device.replicate operands
and results. Otherwise, the single operands and results of the associated
tf.TPUReplicatedInput and tf.TPUReplicatedOutput ops are simply forwarded to
the tf_device.cluster.
For example, the following non-replicated computation:
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
// Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
// with topology `<topology>`.
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_replicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
%replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
%identity = "tf.Identity"(%replicated_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : (tensor<i32>) -> tensor<i32>
%replicated_output = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> tensor<i32>
return %replicated_output : tensor<i32>
}
will be transformed into:
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
The following replicated computation:
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_replicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
%replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%identity = "tf.Identity"(%replicated_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : (tensor<i32>) -> tensor<i32>
%replicated_output:2 = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
}
will be transformed into:
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input: tensor<i32>) {n = 2 : i32} {
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
tf_device.return %cluster : tensor<i32>
}
return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
}
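The grouping and wrapping decision can be sketched in Python. This is an illustrative model under stated assumptions, not the pass itself. Ops and metadata are plain dicts, `form_clusters` is a hypothetical helper, and the operand/result count checks on tf.TPUReplicatedInput and tf.TPUReplicatedOutput are left out. Consistent with both examples above, num_replicas is never copied onto the cluster.

```python
from collections import defaultdict
from typing import Dict, List


def form_clusters(ops: List[dict], metadata: Dict[str, dict]) -> Dict[str, dict]:
    """Groups ops by _replication_info and decides how each cluster is wrapped.

    ops: dicts with a "name" and an "attrs" dict (hypothetical representation).
    metadata: maps a _replication_info value to the attributes of its
      tf.TPUReplicateMetadata op.
    """
    members: Dict[str, List[str]] = defaultdict(list)
    for op in ops:
        attrs = op.get("attrs", {})
        has_info = "_replication_info" in attrs
        has_device_type = "_xla_compile_device_type" in attrs
        # The pass assumes every op has both attributes or neither.
        if has_info != has_device_type:
            raise ValueError(f"{op['name']}: expected both attributes or neither")
        if has_info:
            members[attrs["_replication_info"]].append(op["name"])

    clusters = {}
    for info, names in members.items():
        attrs = dict(metadata[info])  # copy so the caller's dict is untouched
        num_replicas = attrs.pop("num_replicas", 1)
        clusters[info] = {
            "ops": names,            # ops moved into the tf_device.cluster body
            "cluster_attrs": attrs,  # metadata attributes copied to the cluster
            # Only replicated computations get a tf_device.replicate wrapper.
            "replicate_n": num_replicas if num_replicas > 1 else None,
        }
    return clusters
```

Ops without either attribute (host ops) stay outside every cluster. This is why the sketch only collects ops that carry _replication_info.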
-tf-tpu-colocate-composite-resource-ops
Colocate resource with composite device assignment to TPU device.
This pass colocates resource ops that use composite device resources (packed tensors) with the underlying physical TPU device.
For example, given the following op inside a tf_device.replicate:
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource<tensor<4xf32>>>) -> tensor<4xf32>
This ReadVariableOp is then replaced by:
%0 = "tf_device.launch"() ( {
%2 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource<tensor<4xf32>>>) -> tensor<4xf32>
tf_device.return %2 : tensor<4xf32>
}) {...} : () -> tensor<4xf32>
-tf-tpu-colocate-splits
Colocates each Split op with its predecessor
It is beneficial for performance to assign a Split op to the same device
as its predecessor. This is because the weight of cut edges is always
minimized when the Split is with its predecessor. This colocation
constraint will be used by the placer graph optimization to assign a device
to the op.
This pass should run in the export pipeline after tf-replicate-to-island so each replica has its own distinct (predecessor, Split) pair.
The colocation class (_class) of the Split is set to the same class as
its predecessor:
%outputs1:2, %control1 = tf_executor.island wraps "tf.IteratorGetNext"(%arg)
{_class = ["loc:@dataset_iterator_1"]}
%outputs2:2, %control2 = tf_executor.island wraps "tf.Split"(%outputs0, %outputs1#1)
{_class = ["loc:@dataset_iterator_1"], num_split = 2 : i32}
-tf-tpu-device-propagation
Propagates TPU devices from ops to users
-tf-tpu-dynamic-layout-pass
Inserts TPU layout ops to determine layout at run time.
A pass that allows TPU input layout to be determined after JIT compilation. This is done by adding run-time ops that interpret the compilation result and copy the input to the device with that layout.
For example, the original program:
%input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
%compile:2 = "tf._TPUCompileMlir"(...)
%execute = "tf.TPUExecute"(%input, ..., %compile#1) {device = "/TPU:0"}
Without this pass, later TF graph partitioning passes will insert send/recv ops between %input and %execute, and the data will be copied to the device in a fixed layout. With this pass, the program will be transformed into:
%input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
%compile:2 = "tf._TPUCompileMlir"(...)
%get_layout = "tf.TPUGetLayoutOp"(%compile#1) {...}
%copy_to_device = "tf.TPUCopyWithLayout"(%input, %get_layout)
{device = "/TPU:0"}
%execute = "tf.TPUExecute"(%copy_to_device, ..., %compile#1)
{device = "/TPU:0"}
This way, %compile will determine the layout, which will be respected by %copy_to_device. There will not be send/recv ops added by later passes, because tf.TPUCopyWithLayout accepts a host input and produces a device output.
-tf-tpu-host-computation-expansion
Expands host computation before and after TPU computation.
This pass extends the outside compilation attribute to Identity/Cast ops at the head of a TPU computation when they are used only by outside-compiled ops.
-tf-tpu-identity-pruning
Removes Identity/IdentityN ops from the TPU computation
-tf-tpu-merge-variables-with-execute
Merges device variable reads and updates into TPU execute ops
This pass finds on-device resource variable reads and updates surrounding a
tf.TPUExecute op and merges them into a tf.TPUExecuteAndUpdateVariables
op. This allows the TPU execution to perform more efficient in-place
variable updates.
For example,
%0 = "tf.ReadVariableOp"(%arg0)
%1 = "tf.ReadVariableOp"(%arg1)
%2 = "tf.TPUExecute"(%0, %1, %compile)
"tf.AssignVariableOp"(%arg0, %2)
will be transformed into
%2 = "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %compile)
{ device_var_reads_indices = [0, 1],
device_var_updates_indices = [0, -1] }
The transformation happens only for on-device variables. The above
transformation requires %arg0, %arg1 to have the same device assignment
as the TPUExecute op.
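The index attributes in the example can be derived mechanically. This is a hypothetical sketch. Operands are modeled as (kind, name) pairs, and `merge_variables_with_execute` is not a TensorFlow API. It assumes, as the example suggests, that device_var_updates_indices holds, for each read variable, the index of the TPUExecute result written back to it, or -1 if the variable is not updated.

```python
from typing import Dict, List, Tuple


def merge_variables_with_execute(
    operands: List[Tuple[str, str]],
    assigned_from_output: Dict[str, int],
) -> Tuple[List[str], List[int], List[int]]:
    """Computes the merged op's operands and index attributes.

    operands: TPUExecute operands in order, each ("read", resource) when the
      operand comes from a tf.ReadVariableOp, or ("value", name) otherwise.
    assigned_from_output: resource -> index of the TPUExecute result that a
      tf.AssignVariableOp writes back to it.
    """
    new_operands: List[str] = []
    reads_indices: List[int] = []
    updates_indices: List[int] = []
    for position, (kind, name) in enumerate(operands):
        # For a read, the resource handle replaces the read value;
        # any other operand passes through unchanged.
        new_operands.append(name)
        if kind == "read":
            reads_indices.append(position)
            # -1 marks a variable that is read but never written back.
            updates_indices.append(assigned_from_output.get(name, -1))
    return new_operands, reads_indices, updates_indices
```

Applied to the example, `merge_variables_with_execute([("read", "%arg0"), ("read", "%arg1"), ("value", "%compile")], {"%arg0": 0})` yields the operands `%arg0, %arg1, %compile` together with `[0, 1]` and `[0, -1]`.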
-tf-tpu-parallel-execute-sink-resource-write
Moves tf.AssignVariableOp consumers of tf_device.parallel_execute into tf_device.parallel_execute regions
-tf-tpu-partitioned-op-conversion
Rewrite all TPU Partitioned ops into their V2 counterparts.
-tf-tpu-reorder-replicate-partitioned-inputs
Reorder replicated and partitioned input ops.
This pass rewrites how data parallelism and model parallelism are expressed for
inputs. It reorders tf.TPUPartitionedInput (model parallelism) and
tf.TPUReplicatedInput (data parallelism) ops. It transforms a DAG where
multiple tf.TPUPartitionedInput ops feed into a single
tf.TPUReplicatedInput into a DAG where multiple tf.TPUReplicatedInput ops
feed into a single tf.TPUPartitionedInput. Transforming the IR in this
way makes it easier for the subsequent cluster formation pass to handle IR
with both data and model parallelism.
For example, the following:
!rtype = type tensor<!tf_type.resource<tensor<10x3xf32>>>
func @data_and_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> !rtype {
%pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
%pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
%ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (!rtype, !rtype) -> !rtype
return %ri : !rtype
}
will be transformed into:
!rtype = type tensor<!tf_type.resource<tensor<10x3xf32>>>
func @data_and_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> !rtype {
%ri_0 = "tf.TPUReplicatedInput"(%arg0, %arg2) : (!rtype, !rtype) -> !rtype
%ri_1 = "tf.TPUReplicatedInput"(%arg1, %arg3) : (!rtype, !rtype) -> !rtype
%pi = "tf.TPUPartitionedInput"(%ri_0, %ri_1) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
return %pi : !rtype
}
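Viewed as a grid of replicas by cores, this reordering is a transpose. The Python sketch below shows the operand regrouping from the example. The helper name is hypothetical, and the sketch assumes every tf.TPUPartitionedInput has the same number of operands.

```python
from typing import List


def reorder_replicated_partitioned(per_replica_partitions: List[List[str]]) -> List[List[str]]:
    """Regroups operands from replica-major to core-major order.

    per_replica_partitions[r][c] is the operand for replica r and core c, i.e.
    each row is one tf.TPUPartitionedInput feeding the tf.TPUReplicatedInput.
    Each row of the result is one new tf.TPUReplicatedInput, one per core,
    feeding a single tf.TPUPartitionedInput.
    """
    num_cores = len(per_replica_partitions[0])
    if any(len(p) != num_cores for p in per_replica_partitions):
        raise ValueError("every TPUPartitionedInput must have the same number of operands")
    # Transposing the replica x core grid swaps the nesting order.
    return [list(core_operands) for core_operands in zip(*per_replica_partitions)]
```

With the example's operands, `[["%arg0", "%arg1"], ["%arg2", "%arg3"]]` becomes `[["%arg0", "%arg2"], ["%arg1", "%arg3"]]`, matching %ri_0 and %ri_1.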
-tf-tpu-resource-partition
Partitions unpartitioned resource read/write to partitioned resource variables.
This pass creates individual resource reads/writes from the unpartitioned
resource variable (from tf.TPUPartitionedInput) to individual partitioned
resource variables (tf.TPUPartitionedInput operands). As resource op
decomposition/lifting occurs with the unpartitioned resource variables,
transforming the IR in such a manner will allow for subsequent passes to operate
on individual resource variable handles per core/device.
For example, the following:
func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
%partitioned_variable = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<!tf_type.resource<tensor<i32>>>) -> tensor<!tf_type.resource<tensor<i32>>>
%read = "tf.ReadVariableOp"(%partitioned_variable) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%computation = "tf_device.cluster_func"(%read) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%partitioned_variable, %computation) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
return %arg0: tensor<i32>
}
will be transformed into:
func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%computation = "tf_device.cluster_func"(%partitioned_input) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
%partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
return %arg0: tensor<i32>
}
-tf-tpu-resource-read-for-write
Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes with no reads
This pass materializes tf.ReadVariableOp inputs to an outlined TPU computation
for resource variables where only writes are present so later in the pipeline
such resource variables can be fused with generated tf.TPUExecute ops, which
only support resource variable reads or read + write. For all TPU computations,
resource variables are required to be initialized prior to execution. Write-only
resource variable uses can currently be generated via packed tensor uses.
For example, the following:
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf_type.resource<tensor<i32>>>) {
%0 = "tf_device.cluster_func"(%value) {func = @cluster} : (tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @cluster(%arg0: tensor<i32>) -> tensor<i32> {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %identity : tensor<i32>
}
will be transformed into:
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf_type.resource<tensor<i32>>>) {
%resource_read = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%0 = "tf_device.cluster_func"(%value, %resource_read) {func = @cluster} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @cluster(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %identity : tensor<i32>
}
-tf-tpu-rewrite
Rewrites a tf_device.cluster_func on TPUs into TPU runtime operations.
This pass rewrites a tf_device.cluster_func operation into a sequence of tf._TPUCompileMlir
and tf.TPUExecute operations. tf._TPUCompileMlir contains an MLIR module that is
functionally equivalent to the function referenced by tf_device.cluster_func.
This allows the module to be JIT-compiled and executed on TPU.
If it is not possible to rewrite the operation or device assignment fails,
a failure will be returned.
Note that many parameters to the tf_device.cluster_func are omitted in this
and the following examples.
For example, a non-replicated tf_device.cluster_func:
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
return
}
will be rewritten as:
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
%0:2 = "tf_device.launch"() ( {
%compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
tf_device.return %compilation_status, %program : tensor<!tf_type.string>, tensor<3x!tf_type.string>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf_type.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
%1 = "tf_device.launch"() ( {
%2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i8>, tensor<3x!tf_type.string>) -> tensor<i8>
tf_device.return %2 : tensor<i8>
}) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor<i8>
return
}
A replicated tf_device.cluster_func:
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
%0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i8>) {n = 2 : i32} {
%1 = "tf_device.cluster_func"(%ri) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
tf_device.return %1 : tensor<i8>
}
return
}
will be rewritten as:
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
%0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor<i8>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST_0 = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} {
%1:2 = "tf_device.launch"() ( {
%compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
tf_device.return %compilation_status, %program : tensor<!tf_type.string>, tensor<3x!tf_type.string>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf_type.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
%2 = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor<i8>, tensor<3x!tf_type.string>) -> tensor<i8>
tf_device.return %3 : tensor<i8>
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i8>
tf_device.return %2 : tensor<i8>
}
return
}
A non-replicated tf_device.cluster_func with model parallelism:
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32>
return %0 : tensor<8xi32>
}
will be rewritten as:
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0:3 = "tf_device.launch"() ( {
%compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>, tensor<3x!tf_type.string>)
tf_device.return %compilation_status, %program#0, %program#1 : tensor<!tf_type.string>, tensor<3x!tf_type.string>, tensor<3x!tf_type.string>
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>, tensor<3x!tf_type.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf_type.string>) -> ()
tf_device.return
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
%1 = "tf_device.parallel_execute"() ( {
%2 = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf_type.string>) -> tensor<8xi32>
tf_device.return %3 : tensor<8xi32>
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32>
tf_device.return %2 : tensor<8xi32>
}, {
"tf_device.launch"() ( {
"tf.TPUExecute"(%0#2) : (tensor<3x!tf_type.string>) -> ()
tf_device.return
}) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> ()
tf_device.return
}) : () -> tensor<8xi32>
return %1 : tensor<8xi32>
}
-tpu-compile-metadata-debug : Whether to serialize TPUCompileMetadataProto metadata in 'tf._TPUCompileMlir' op as a proto debug string
-tf-tpu-sharding-identification
Identifies and handles inputs/outputs of TPU computation that is sharded across logical cores.
Bubbles up sharding configuration from cluster_func regions into
the attributes of cluster_func. This is done by parsing the
XlaSharding / TPUPartitionedOutput / TPUPartitionedInput ops inside
cluster_func.
For example, given the following cluster_func wrapping func:
func @test(%arg0: tensor<*xi32>) {
"tf_device.cluster_func"(%arg0) {
func = @func,
step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>
return
}
func @func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "\01\02\03",
sharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
return %1 : tensor<*xi32>
}
After this pass, cluster_func receives the following *_sharding_configuration
attributes, and func receives the mhlo.sharding attribute:
func @test(%arg0: tensor<*xi32>) {
%0 = "tf_device.cluster_func"(%arg0) {
func = @func,
input_sharding_configuration = ["\01\02\03"],
output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"],
step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>
return
}
func @func(%arg0: tensor<*xi32> {mhlo.sharding = "\01\02\03"}) ->
(tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
%0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "\01\02\03", sharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.A"(%0) : (tensor<*xi32>) -> tensor<*xi32>
return %1 : tensor<*xi32>
}
-tf-tpu-space-to-depth-pass
Applies automatic space to depth transform for the first or frontier convolutions that consume host inputs on TPU.
The automatic space-to-depth transform adds a space-to-depth op after the host input and applies the space-to-depth transform to the first convolution and its backprop filter on TPU.
For example, original program:
module {
func @while_body {
%input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}: -> tensor<2x224x224x3xf32>
%device_launch = "tf_device.cluster_func"(%input,...) {func = @_func,...}
return ...
}
func @_func(%input: tensor<2x224x224x3xf32>, %filter: tensor<7x7x3x64xf32>) {
%6 = "tf.Conv2D"(%input, %filter) {strides = [1, 2, 2, 1]}: (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32>
}
}
The program will be transformed into:
module {
func @while_body {
%input = "tf.IteratorGetNext"(...) {device = "/CPU:0"} -> tensor<2x224x224x3xf32>
%space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}: (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
%device_launch = "tf_device.cluster_func"(%space_to_depth,...) {func = @_func,...}
return ...
}
func @_func(%input: tensor<2x112x112x12xf32>, %filter: tensor<7x7x3x64xf32>) {
%filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter): (tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32>
%conv = "tf.Conv2D"(%input, %filter_transform) {strides = [1, 1, 1, 1]}: (tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32>
}
}
This way, the first convolution, whose input has a feature (channel) dimension of 3, is transformed into one with a feature dimension of 12, which has better performance on TPU.
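The shape arithmetic in the example can be checked with a short Python sketch. The helper names are hypothetical. Two details are assumptions inferred from the example shapes rather than taken from the pass source. First, the filter is zero-padded from 7x7 to 8x8 before folding, which is suggested by the tf.Pad/tf.Transpose/tf.Reshape placeholder and the resulting 4x4x12x64 shape. Second, the stride must be divisible by the block size.

```python
from typing import Tuple

Shape4 = Tuple[int, int, int, int]


def space_to_depth_input_shape(nhwc: Shape4, block_size: int) -> Shape4:
    # Each block_size x block_size spatial patch is folded into channels.
    n, h, w, c = nhwc
    if h % block_size or w % block_size:
        raise ValueError("spatial dims must be divisible by block_size")
    return (n, h // block_size, w // block_size, c * block_size * block_size)


def space_to_depth_filter_shape(hwio: Shape4, block_size: int) -> Shape4:
    kh, kw, cin, cout = hwio
    # Pad each kernel spatial dim up to a multiple of block_size (tf.Pad) ...
    padded_h = -(-kh // block_size) * block_size
    padded_w = -(-kw // block_size) * block_size
    # ... then fold block_size x block_size taps into the input-channel dim
    # (tf.Transpose / tf.Reshape).
    return (padded_h // block_size, padded_w // block_size,
            cin * block_size * block_size, cout)


def space_to_depth_stride(stride: int, block_size: int) -> int:
    # Assumed: the original stride is reduced by the block size.
    if stride % block_size:
        raise ValueError("stride must be divisible by block_size")
    return stride // block_size
```

With block size 2, these helpers reproduce the example: the input becomes 2x112x112x12, the filter becomes 4x4x12x64, and the stride drops from 2 to 1.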
-tf-tpu-update-embedding-enqueue-op-inputs
Updates inputs to TPU embedding enqueue ops depending on whether the graph is in training mode or in evaluation mode.
-tf-tpu-validate-inputs
Validates inputs to the TPU TF/XLA bridge
This pass checks that the IR is a valid input to the TPU TF/XLA bridge. It checks relations between multiple ops; properties of individual ops are checked by each op's 'verify' method.
-tf-tpu-variable-runtime-reformatting
Adds device variable formatting op to allow compilation-guided variable formatting.
A pass that takes advantage of a loop to add ops that let execution avoid repeatedly reformatting variables back and forth. The desired format is determined by TPU program compilation, so this pass does not specify how to reformat the variables. It only inserts generic TPUReshardVariablesOps in the proper places, and those ops interpret the compilation result.
The core idea of this optimization is to keep track of the formatting state of variables and to skip reformatting when the next desired state is unchanged. A set of variables on a device is associated with a formatting state, and TPUReshardVariablesOp compares the current state with a desired state (which can be the compilation result). If they mismatch, TPUReshardVariablesOp reformats the variables to the desired state; if they match, TPUReshardVariablesOp is a no-op.
A major use of this pass is weight-update sharding in data parallelism, so the pass requires a tf_device.replicate inside the loop.
For example, suppose we have a training loop (for simplicity, the loop body is written inline):
%var0 = ...
%var1 = ...
tf.while (..., %var0, %var1) {
tf_device.replicate ([%var0, %var1] as %rvar) {
%compile:2 = "tf._TPUCompileMlir"()
tf.TPUExecuteAndUpdateVariablesOp(%rvar, %compile#1)
}
}
This pass will transform it into
%var0 = ...
%var1 = ...
%state_var0 = ...
%state_var1 = ...
tf.while (..., %var0, %var1, %state_var0, %state_var1) {
tf_device.replicate ([%var0, %var1] as %rvar,
[%state_var0, %state_var1] as %rstate) {
%compile:2 = "tf._TPUCompileMlir"()
tf.TPUReshardVariablesOp(%rvar, %compile#1, %rstate)
tf.TPUExecuteAndUpdateVariablesOp(%rvar, %compile#1)
}
}
%default_format = tf.constant()
tf_device.replicate ([%var0, %var1] as %rvar,
[%state_var0, %state_var1] as %rstate) {
tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
}
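The state comparison that makes TPUReshardVariablesOp a no-op on later loop iterations can be modeled as follows. This is a hypothetical Python sketch, not the op's implementation. Formats are plain strings, and `ReshardState` stands in for per-device state variables such as %state_var0.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ReshardState:
    """Hypothetical stand-in for a per-device format state variable."""
    current_format: str = "default"
    history: List[str] = field(default_factory=list)


def reshard_variables(state: ReshardState, desired_format: str) -> bool:
    """Models TPUReshardVariablesOp: reformat only when the format changes.

    Returns True if the variables were reformatted, False for a no-op.
    """
    if state.current_format == desired_format:
        return False
    state.history.append(f"{state.current_format}->{desired_format}")
    state.current_format = desired_format
    return True
```

Inside the loop, every iteration requests the same compiled format, so only the first call reformats. The reshard after the loop then restores the default format once.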
-tf-unroll-batch-matmul
Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.
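The unrolling can be illustrated on nested Python lists. This is a minimal sketch with hypothetical helper names that stand in for the tf.Reshape, tf.Slice, tf.MatMul, and tf.Pack ops. It assumes both operands have identical batch dimensions, so broadcasting of batch dimensions is not modeled.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain 2-D matrix product, standing in for tf.MatMul.
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions do not match")
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(len(b[0]))]
            for i in range(len(a))]


def unrolled_batch_matmul(x: List[List[Matrix]], y: List[List[Matrix]]) -> List[List[Matrix]]:
    """x: [B1][B2][M][K], y: [B1][B2][K][N]; batch dims must match exactly."""
    b2 = len(x[0])
    # tf.Reshape: flatten the two batch dims into one of size B1*B2.
    flat_x = [m for group in x for m in group]
    flat_y = [m for group in y for m in group]
    if len(flat_x) != len(flat_y):
        raise ValueError("batch dimensions do not match")
    # tf.Slice + tf.MatMul per batch entry, then tf.Pack to stack the results.
    packed = [matmul(xb, yb) for xb, yb in zip(flat_x, flat_y)]
    # tf.Reshape: restore the original batch dims.
    return [packed[i:i + b2] for i in range(0, len(packed), b2)]
```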
-tf-verify-for-export
Verify module is suitable for export back to TF Graph
Verifies that every function in the module consists of a single tf_executor.graph and that each tf_executor.island in the tf_executor.graph wraps only a single op.
-tf-xla-call-module-deserialization
Deserializes StableHLO functions embedded in tf.XlaCallModule to top level module
This pass deserializes the StableHLO bytecodes embedded in tf.XlaCallModule, then outlines the functions in the deserialized StableHLO module to the top level MLIR module, with function renamings to avoid naming conflicts.
After the outlining, it sets tf.XlaCallModule's module attribute to empty
and adds an _entry_function attribute referring to the entry function.
It also adds a _from_xla_call_module: true attribute to each lifted
StableHLO function.
-tf-xla-call-module-serialization
Serializes StableHLO functions from top-level module into tf.XlaCallModule's module attribute
This pass collects the StableHLO functions referenced from tf.XlaCallModule's
_entry_function attribute into a module, serializes the module into MLIR
bytecode, and embeds the bytecode in tf.XlaCallModule's module attribute.
After serialization, this pass removes the _entry_function attribute from
tf.XlaCallModule and removes all the serialized StableHLO functions
from the top-level module.
-tfe-legalize-tfg
Legalize from TFG to the TFE dialect