# AutoGraph operators
This section describes the semantics of the operators used in code generated by AutoGraph. Understanding these operators will make it easier to read the generated code.
AutoGraph operators are Python functions that replace certain Python constructs in the generated code.
For example, the following statement:
```
if x:
  y = 1
else:
  y = 2
```
Will result in the following generated code:
```
def get_state():
  return (y,)

def set_state(vars_):
  nonlocal y
  (y,) = vars_

def if_body():
  nonlocal y
  y = 1

def else_body():
  nonlocal y
  y = 2

y = ag__.Undefined('y')
ag__.if_stmt(ag__.ld(x), if_body, else_body, get_state, set_state, ('y',), 1)
```
In the example above, `ag__.if_stmt`, `ag__.ld` and `ag__.Undefined` are all
AutoGraph operators.
The source of truth for these operators is the source code of the
`tensorflow.python.autograph.operators` module. All public symbols exported by
that module are considered operators.
AutoGraph replaces Python statements with operators in order to enable
type-based dispatch. If Python didn't support things like `__add__`, then
AutoGraph would need an `add` operator as well.
Dispatch means simply that the operator does different things based on the type of input.
Generally, the dispatch follows these rules:
*   For non-TensorFlow inputs, the operator behaves like the corresponding
    Python construct; for example, `ag__.not_(False)` always has the same
    result as `not False`.
*   For TensorFlow inputs, the operator maps to an equivalent TensorFlow
    function; for example, `ag__.eq(tf.constant(1), tf.constant(2))` has the
    same result as `tf.math.equal(tf.constant(1), tf.constant(2))`.

The first rule above means that if you convert normal, non-TensorFlow code with
AutoGraph and call it with non-TensorFlow inputs, executing the generated code
should be no different than executing the original.
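To see both rules in action, here is a minimal, hedged sketch (the function
name `square_if_positive` and the values are made up for this illustration)
that converts a function with `tf.autograph.to_graph` and calls it with both
Python and Tensor inputs:

```
import tensorflow as tf

def square_if_positive(x):
  # Plain Python control flow; AutoGraph rewrites it into operator calls.
  if x > 0:
    x = x * x
  else:
    x = 0.0
  return x

converted = tf.autograph.to_graph(square_if_positive)

# Python input: the generated operators dispatch to ordinary Python behavior,
# so the result is identical to calling the original function.
assert converted(3.0) == square_if_positive(3.0) == 9.0

# Tensor input: the same generated code dispatches to tf.cond.
print(converted(tf.constant(3.0)))  # => tf.Tensor(9.0, shape=(), dtype=float32)
```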
All AutoGraph operators use pure functional forms. This may sometimes mean that
expressions which normally appear bare in Python are wrapped inside a function
(also known as a thunk). If a Python expression appears as just `foo`, then its
corresponding thunk is `lambda: foo`.
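As a minimal illustration, the sketch below (the variable names are invented
for this example) contrasts a bare expression with its thunk form:

```
x = 41

# The bare expression evaluates immediately.
bare = x + 1

# The thunk delays evaluation until the operator decides to call it. This is
# how operators such as if_exp receive their branch values, so the operator
# controls when (and whether) each branch runs.
thunk = lambda: x + 1

assert thunk() == bare == 42
```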
## if_exp

The Python conditional expression: `foo if bar else baz`.
Args:

*   `cond`: the condition expression; same as `cond` in `_ if cond else _`.
*   `if_true`: the true value (as a thunk); same as `lambda: x` in
    `x if _ else _`.
*   `if_false`: the false value (as a thunk); same as `lambda: x` in
    `_ if _ else x`.
*   `expr_repr`: a human-readable string representing `cond`, used for error
    messages.
Example:

```
true_val if cond else false_val
```

```
ag__.if_exp(cond, lambda: true_val, lambda: false_val, 'cond')
```
Dispatch on `cond`:

*   `tf.Tensor`: to `tf.cond`, checking that `true_val` and `false_val` have
    compatible shape and type (see the sketch below).
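As a rough illustration of this dispatch, the hedged sketch below (the function
name `relu_scalar` and the values are invented) uses a conditional expression
on a Tensor condition inside a `tf.function`:

```
import tensorflow as tf

@tf.function
def relu_scalar(x):
  # The conditional expression becomes if_exp; with a Tensor condition it
  # dispatches to tf.cond, so both branches must have compatible shape/dtype.
  return x if x > 0 else tf.zeros_like(x)

print(relu_scalar(tf.constant(-2.0)))  # => tf.Tensor(0.0, shape=(), dtype=float32)
print(relu_scalar(tf.constant(3.0)))   # => tf.Tensor(3.0, shape=(), dtype=float32)
```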
Unlike Python, AutoGraph control flow operators use explicit control flow
variables, which include all symbols that are modified by the control flow.

For example, the code below has a single loop variable, `x`:

```
while x < 3:
  x = x + 1
```
In addition, control flow that is dispatched to a non-Python implementation is
subject to the restrictions of that implementation. For example,
`tf.while_loop` requires that all loop variables have supported types (e.g.
`tf.Tensor`s of consistent shape and dtype).
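The hedged sketch below (the function name `count_up_to` and the values are
invented) shows a loop that satisfies these restrictions: its single loop
variable keeps a consistent shape and dtype on every iteration:

```
import tensorflow as tf

@tf.function
def count_up_to(limit):
  # `x` is the only loop variable; it keeps the same dtype and shape on every
  # iteration, as required when the loop dispatches to tf.while_loop.
  x = tf.constant(0, dtype=limit.dtype)
  while x < limit:
    x = x + 1
  return x

print(count_up_to(tf.constant(5)))  # => tf.Tensor(5, shape=(), dtype=int32)
```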
## for_stmt

For loop: `for var in target: body`, extended with a per-iteration condition to
handle early termination (e.g. due to a `break`).
Args:

*   `iter_`: the object being iterated; same as `n` in `for _ in n`.
*   `body`: the loop body (as a function of the iterate); same as
    `def body(i): <b>` in `for i in _: <b>`.

Example:
```
for i in range(3):
  j = j + i
```

```
def get_state():
  return (j,)

def set_state(vars_):
  nonlocal j
  (j,) = vars_

def loop_body(itr):
  nonlocal j
  i = itr
  j = j + i

ag__.for_stmt(range(3), None, loop_body, get_state, set_state, ('j',), {})
```
Example (using `extra_test`):
```
for i in range(3):
  if i > 2:
    break
  j = j + i
```

```
def get_state():
  return (j,)

def set_state(vars_):
  nonlocal j
  (j,) = vars_

def loop_body(itr):
  nonlocal j
  i = itr
  j = j + i

def extra_test():
  return not (i > 2)

ag__.for_stmt(range(3), extra_test, loop_body, get_state, set_state, ('j',), {})
```
Dispatch on `iter_`:

*   Python values: to a Python `for` loop (terminating early when indicated by
    `extra_test`).
*   `tf.Tensor` produced by `tf.range`: to `tf.while_loop`, removing the
    `tf.range`.
*   `tf.Tensor`, `tf.RaggedTensor`: to `tf.while_loop`, checking the loop vars
    for consistency. `opts` are forwarded to `tf.while_loop`. Iterates over the
    outermost dimension of the tensor (similar to `tf.map_fn`).
*   `tf.data.Dataset`: to `tf.data.Dataset.take_while`, checking the loop vars
    for consistency (see the sketch below).
*   `tf.data.Iterator`, `tf.distribute.Iterator`: to `tf.while_loop` called on
    the iterator's `get_next_as_optional`, checking the loop vars for
    consistency.
*   `tf.distribute.Iterable`: to `tf.distribute.Iterable.reduce`.
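As a rough illustration of the `tf.data.Dataset` case, here is a hedged sketch
(the function name `sum_range` and the values are invented):

```
import tensorflow as tf

@tf.function
def sum_range(n):
  ds = tf.data.Dataset.range(n)
  total = tf.constant(0, dtype=tf.int64)
  # The for loop below becomes for_stmt; because `ds` is a tf.data.Dataset,
  # it dispatches to the Dataset-based code path rather than a Python loop.
  for value in ds:
    total = total + value
  return total

print(sum_range(5))  # => tf.Tensor(10, shape=(), dtype=int64)
```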
## if_stmt

If statement: `if cond: body else: orelse_body`.

Args:

*   `cond`: same as `cond` in `if cond`.
*   `body`: same as `def body(): <b>` in `if _: <b>`.
*   `orelse`: same as `def orelse(): <b>` in `if _: ... else: <b>`.

Example:
```
if k > 1:
  j = j + i
```

```
def get_state():
  return (j, i)

def set_state(vars_):
  nonlocal j, i
  (j, i) = vars_

def body():
  nonlocal j, i
  j = j + i

def orelse():
  pass

ag__.if_stmt(k > 1, body, orelse, get_state, set_state, ('j', 'i'), 1)
```
Dispatch on `cond`:

*   `tf.Tensor`: to `tf.cond` (see the sketch below).
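As a rough illustration of this dispatch, here is a hedged sketch (the function
name `clamp_negative` and the values are invented):

```
import tensorflow as tf

@tf.function
def clamp_negative(x):
  # The if statement becomes if_stmt; with a Tensor condition it dispatches to
  # tf.cond, so `x` must have a compatible shape and dtype in both branches.
  if x < 0:
    x = tf.zeros_like(x)
  return x

print(clamp_negative(tf.constant(-3.0)))  # => tf.Tensor(0.0, shape=(), dtype=float32)
print(clamp_negative(tf.constant(2.0)))   # => tf.Tensor(2.0, shape=(), dtype=float32)
```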
## while_stmt

While loop: `while cond: body`.

Args:

*   `test`: the loop condition (as a thunk); same as `def test(): cond` in
    `while cond`.
*   `body`: the loop body; same as `def body(): <b>` in `while _: <b>`.

Example:
```
while j > 10:
  j = j + i
```

```
def get_state():
  return (j,)

def set_state(vars_):
  nonlocal j
  (j,) = vars_

def loop_test():
  nonlocal j
  return j > 10

def loop_body():
  nonlocal j
  j = j + i

ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('j',), {})
```
Dispatch on the return type of `test`:

*   `tf.Tensor`: to `tf.while_loop`.

## list_append

List append operation: `l.append(x)`. Callers should assume that the list
argument is modified, if that is possible.
Args:

*   `list_`: the list to append to; same as `l` in `l.append(x)`.
*   `x`: the appended value; same as `x` in `l.append(x)`.

Returns:

*   the updated list value, which callers should use in place of the original
    list.

Example:

```
l.append(x)
```

```
l = ag__.list_append(l, x)
```
Dispatch on `list_`:

*   default: to `list_.append`.
*   `tf.Tensor`: to `tf.raw_ops.TensorListPushBack`.
*   `tf.TensorArray`: to `tf.TensorArray.write` (see the sketch below).
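As a rough illustration of the `tf.TensorArray` code path, here is a hedged
sketch (the function name `squares` and the values are invented);
`tf.TensorArray.write` is the call that `list_append` dispatches to in that
case:

```
import tensorflow as tf

@tf.function
def squares(n):
  # tf.TensorArray.write is what list_append lowers to when the list argument
  # is a tf.TensorArray.
  ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  i = tf.constant(0)
  while i < n:
    ta = ta.write(i, i * i)
    i += 1
  return ta.stack()

print(squares(tf.constant(5)))  # => tf.Tensor([ 0  1  4  9 16], shape=(5,), dtype=int32)
```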
## list_pop

List pop operation: `l.pop(i)`. Callers should assume that the list argument is
modified, if that is possible.
Args:

*   `list_`: the list to pop from; same as `l` in `l.pop()`.

Returns:

*   a tuple of the updated list value and the popped element.

Example:

```
x = l.pop()
```

```
l, x = ag__.list_pop(l)
```
Dispatch on `list_`:

*   default: to `list_.pop`.
*   `tf.Tensor`: to `tf.raw_ops.TensorListPopBack`.

The remaining operators and option structures are: `list_stack`, `ListPopOpts`,
`ListStackOpts`, `new_list`, `assert_stmt`, `and_`, `eq`, `not_`, `not_eq`,
`or_`, `float_`, `int_`, `len_`, `print_`, `range_`, `get_item`, `GetItemOpts`,
`set_item`, `ld`, `ldu`, `Undefined` and `UndefinedReturnValue`.