
AutoGraph reference

tensorflow/python/autograph/g3doc/reference/operators.md

Version 2.21.0


Index

Operator semantics

Definition

This section describes the semantics of the operators used in code generated by AutoGraph. Understanding these operators will make it easier to read the generated code.

AutoGraph operators are Python functions that replace certain Python constructs in the generated code.

For example, the following statement:

if x:
  y = 1
else:
  y = 2

This results in the following generated code:

def get_state():
    return (y,)

def set_state(vars_):
    nonlocal y
    (y,) = vars_

def if_body():
    nonlocal y
    y = 1

def else_body():
    nonlocal y
    y = 2

y = ag__.Undefined('y')
ag__.if_stmt(ag__.ld(x), if_body, else_body, get_state, set_state, ('y',), 1)

In the example above, ag__.if_stmt, ag__.ld and ag__.Undefined are all AutoGraph operators.

The source of truth for these operators is the source code of the AutoGraph operators module. All public symbols exported by that module are considered operators.

Type-based dispatch

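AutoGraph replaces Python statements with operators in order to enable type-based dispatch. Python already dispatches some operations by type through special methods such as __add__, which is why AutoGraph has no add operator. Operators are needed for constructs that cannot be overloaded this way, such as if statements, loops, and the and, or and not operators.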

Dispatch simply means that the operator behaves differently depending on the type of its input.

Generally, the dispatch follows these rules:

  • if the input is a type that would execute normally under Python (also referred to as "the default path"), then AutoGraph always reverts to the corresponding Python operator. For example, ag__.not_(False) always has the same result as not False.
  • if the input is a TensorFlow type, then AutoGraph typically dispatches to an equivalent TensorFlow API, performs additional checks or just raises an error. For example, ag__.eq(tf.constant(1), tf.constant(2)) has the same result as tf.math.equal(tf.constant(1), tf.constant(2)).

The first rule above means that if you convert normal, non-TensorFlow code with AutoGraph and call it with non-TensorFlow inputs, executing the generated code should be no different from executing the original.
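Both rules can be illustrated with a simplified, hypothetical version of the not_ operator. This is a sketch, not AutoGraph's actual implementation: it assumes TensorFlow is optional and uses tf.is_tensor to detect the TensorFlow case and tf.logical_not as the equivalent TensorFlow API.

```python
try:
    import tensorflow as tf  # optional; only needed for the TensorFlow path
except ImportError:
    tf = None


def not_(x):
    """Hypothetical, simplified dispatching `not` operator."""
    if tf is not None and tf.is_tensor(x):
        # TensorFlow type: dispatch to the equivalent TensorFlow op.
        return tf.logical_not(x)
    # Default path: behave exactly like the Python operator.
    return not x
```

For any non-tensor input, not_(x) returns the same value as not x, which is the guarantee described by the first rule.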

Functional form

All AutoGraph operators use pure functional forms. This sometimes means that expressions which normally appear bare in Python are wrapped inside a function (also known as a thunk). If a Python expression appears as just foo, then the corresponding thunk is lambda: foo.
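Thunks let an operator decide whether, and when, an expression is evaluated. The following minimal sketch assumes a hypothetical, Python-only and_ that receives both operands as thunks. It shows how this form preserves Python's short-circuit evaluation:

```python
def and_(a, b):
    """Hypothetical Python-only `and` operator; a and b are thunks."""
    a_val = a()
    if not a_val:
        # Short-circuit: b is never called, just like Python's `and`.
        return a_val
    return b()


calls = []


def expensive():
    calls.append("expensive")
    return True


# Equivalent to `False and expensive()`: expensive() must not run.
result = and_(lambda: False, lambda: expensive())
```

If the second operand were passed as a plain value instead of a thunk, it would be evaluated before the operator is even called, and the side effect would always happen.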

Operator list

Conditional expressions

Source

if_exp

Source

The Python conditional expression: foo if bar else baz.

Args:

  • cond: expression condition; same as cond in _ if cond else _.
  • if_true: true value (as thunk); same as lambda: x in x if _ else _.
  • if_false: false value (as thunk); same as lambda: x in _ if _ else x.
  • expr_repr: human-readable string representing cond. Used for error messages.

Example:

true_val if cond else false_val
ag__.if_exp(cond, lambda: true_val, lambda: false_val, 'cond')

Dispatch on cond:

  • default: to the Python conditional expression.
  • tf.Tensor: to tf.cond, checking that true_val and false_val have compatible shape and type.
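The dispatch above can be sketched as follows. This is a simplified illustration, not AutoGraph's actual code: it assumes TensorFlow is optional, omits the shape and type compatibility checks, and leaves expr_repr unused (the real operator uses it in error messages).

```python
try:
    import tensorflow as tf  # optional; only used when cond is a tensor
except ImportError:
    tf = None


def if_exp(cond, if_true, if_false, expr_repr):
    """Simplified sketch of if_exp dispatch; expr_repr is unused here."""
    if tf is not None and tf.is_tensor(cond):
        # Both thunks become the two branches of tf.cond.
        return tf.cond(cond, if_true, if_false)
    # Default path: only the selected thunk is evaluated.
    return if_true() if cond else if_false()
```

Because the branches are thunks, the default path never evaluates the branch that is not taken, so if_exp(0, lambda: 1 / 0, lambda: 2, 'x') returns 2 without raising.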

Control flow

Source

Unlike Python, AutoGraph control flow operators use explicit control flow variables, which include all symbols modified by the control flow statement.

For example, the code below has a single loop variable, x:

while x < 3:
  x = x + 1

In addition, control flow that is dispatched to a non-Python implementation is subject to the restrictions of that implementation. For example, tf.while_loop requires that all loop variables have supported types (e.g. a Tensor of consistent shape and dtype).
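Explicit loop variables mean a loop can be expressed as a pure function from the current loop-variable values to the next ones, which is the form tf.while_loop expects. The sketch below is a plain-Python stand-in; functional_while is a hypothetical name used only for illustration.

```python
def functional_while(cond, body, loop_vars):
    """Run body while cond holds, threading loop variables explicitly.

    cond and body receive the loop variables as positional arguments;
    body returns a tuple containing their new values.
    """
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


# `while x < 3: x = x + 1`, with x as the single loop variable.
(x,) = functional_while(lambda x: x < 3, lambda x: (x + 1,), (0,))
```

With TensorFlow, the analogous call would be tf.while_loop(lambda x: x < 3, lambda x: (x + 1,), (tf.constant(0),)). There, every value in loop_vars must keep a consistent dtype and shape across iterations.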

for_stmt

Source

For loop: for var in target: body, extended with a per-iteration condition to handle early termination (e.g. due to a break).

Args:

  • iter_: iteration target; same as n in for _ in n.
  • extra_test: optional extra per-iteration condition (as thunk).
  • body: loop body (as unary thunk); same as def body(i): <b> in for i in _: <b>.
  • get_state: returns the current value of the loop variables
  • set_state: sets new values into the loop variables
  • symbol_names: human-readable string representing each loop variable. Used for error messages.
  • opts: additional, implementation-specific, keyword arguments.

Example:

for i in range(3):
  j = j + i

def get_state():
    return (j,)

def set_state(vars_):
    nonlocal j
    (j,) = vars_

def loop_body(itr):
    nonlocal j
    i = itr
    j = j + i

ag__.for_stmt(range(3), None, loop_body, get_state, set_state, ('j',), {})

Example (using extra_test):

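for i in range(3):
  if i > 2:
    break
  j = j + i

break_ = False

def get_state():
    return (j, break_)

def set_state(vars_):
    nonlocal j, break_
    (j, break_) = vars_

def loop_body(itr):
    nonlocal j, break_
    i = itr
    # Simplified: AutoGraph also converts this inner if statement.
    if i > 2:
        break_ = True
    else:
        j = j + i

def extra_test():
    # Keep iterating only while no break has occurred.
    return not break_

ag__.for_stmt(range(3), extra_test, loop_body, get_state, set_state, ('j', 'break_'), {})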

Dispatch on iter_:

  • default: to Python for loop (accounting for extra_test).
  • tf.Tensor produced by tf.range: to tf.while_loop, removing the tf.range.
  • tf.Tensor, tf.RaggedTensor: to tf.while_loop, checking the loop vars for consistency. opts forwarded to tf.while_loop. Iterates over the outermost dimension of the tensor (similar to tf.map_fn).
  • tf.data.Dataset: to tf.data.Dataset.take_while, checking the loop vars for consistency.
  • tf.data.Iterator, tf.distribute.Iterator: to tf.while_loop called on the iterator's get_next_as_optional, checking the loop vars for consistency.
  • tf.distribute.Iterable: to tf.distribute.Iterable.reduce.
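The default path can be sketched in plain Python. This is a simplified illustration, not AutoGraph's implementation: the state callbacks, symbol_names and opts matter only to the TensorFlow implementations, so the sketch ignores them. converted_for_loop is a hypothetical, hand-written equivalent of the generated code for a loop with a break.

```python
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
    """Sketch of the default (pure Python) path of for_stmt."""
    for target in iter_:
        # extra_test is checked before every iteration, e.g. to honor a break.
        if extra_test is not None and not extra_test():
            break
        body(target)


def converted_for_loop(j):
    # Hand-written equivalent of the generated code for:
    #   for i in range(5):
    #     if i > 2:
    #       break
    #     j = j + i
    break_ = False

    def get_state():
        return (j, break_)

    def set_state(vars_):
        nonlocal j, break_
        (j, break_) = vars_

    def loop_body(itr):
        nonlocal j, break_
        i = itr
        if i > 2:
            break_ = True
        else:
            j = j + i

    def extra_test():
        return not break_

    for_stmt(range(5), extra_test, loop_body, get_state, set_state, ('j', 'break_'), {})
    return j
```

converted_for_loop(0) returns 3 (0 + 0 + 1 + 2), the same result as running the original loop directly.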
if_stmt

Source

If statement: if cond: body else: orelse-body.

Args:

  • cond: if condition; same as cond in if cond.
  • body: true branch (as thunk); same as def body(): <b> in if _: <b>.
  • orelse: false branch (as thunk); same as def orelse(): <b> in if _: _ else: <b>.
  • get_state: returns the current value of the conditional variables
  • set_state: sets new values into the conditional variables
  • symbol_names: human-readable string representing each conditional variable. Used for error messages.
  • nouts: number of output conditional variables. Not all conditional variables are outputs; some are only inputs. The first nouts values in get_state and set_state are the conditional outputs.

Example:

if k > 1:
  j = j + i

def get_state():
    return (j, i)

def set_state(vars_):
    nonlocal j, i
    (j, i) = vars_

def body():
    nonlocal j, i
    j = j + i

def orelse():
    pass

ag__.if_stmt(k > 1, body, orelse, get_state, set_state, ('j', 'i'), 1)

Dispatch on cond:

  • default: to Python if statement.
  • tf.Tensor: to tf.cond, checking that both branches produce consistent values for the conditional outputs.
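On the default path, if_stmt reduces to an ordinary if statement that calls one of the two thunks. The sketch below is a simplified illustration, not AutoGraph's code: it ignores get_state, set_state, symbol_names and nouts, which matter only when the statement is staged into tf.cond. converted_if is a hypothetical, hand-written equivalent of the generated code in the example above.

```python
def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
    """Sketch of the default (pure Python) path of if_stmt."""
    if cond:
        body()
    else:
        orelse()


def converted_if(k, i, j):
    # Hand-written equivalent of the generated code for:
    #   if k > 1:
    #     j = j + i
    def get_state():
        return (j, i)

    def set_state(vars_):
        nonlocal j, i
        (j, i) = vars_

    def body():
        nonlocal j, i
        j = j + i

    def orelse():
        pass

    if_stmt(k > 1, body, orelse, get_state, set_state, ('j', 'i'), 1)
    return j
```

Here nouts is 1 because only j is modified (an output); i is read by the branch but never reassigned, so it is only an input.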
while_stmt

Source

While loop: while cond: body.

Args:

  • test: loop condition (as thunk); same as def test(): cond in while cond.
  • body: loop body (as thunk); same as def body(): <b> in while _: <b>.
  • get_state: returns the current value of the loop variables
  • set_state: sets new values into the loop variables
  • symbol_names: human-readable string representing each loop variable. Used for error messages.
  • opts: additional, implementation-specific, keyword arguments.

Example:

while j > 10:
  j = j + i

def get_state():
    return (j,)

def set_state(vars_):
    nonlocal j
    (j,) = vars_

def loop_test():
    nonlocal j
    return j > 10

def loop_body():
    nonlocal j
    j = j + i

ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('j',), {})

Dispatch on the return type of test:

  • default: to Python while loop.
  • tf.Tensor: to tf.while_loop.
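A plain-Python sketch of the default path follows. It is simplified and ignores the state callbacks, symbol_names and opts. converted_while is a hypothetical, hand-written equivalent of the generated code above. Because dispatch depends on what test returns, the operator has to call test at least once before it can choose an implementation.

```python
def while_stmt(test, body, get_state, set_state, symbol_names, opts):
    """Sketch of the default (pure Python) path of while_stmt."""
    while test():
        body()


def converted_while(j, i):
    # Hand-written equivalent of the generated code for:
    #   while j > 10:
    #     j = j + i
    def get_state():
        return (j,)

    def set_state(vars_):
        nonlocal j
        (j,) = vars_

    def loop_test():
        return j > 10

    def loop_body():
        nonlocal j
        j = j + i

    while_stmt(loop_test, loop_body, get_state, set_state, ('j',), {})
    return j
```

For example, converted_while(20, -3) steps j through 17, 14, 11 and 8, then stops and returns 8.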

Data structures

Source

list_append

Source

List append operation: l.append(x). Callers should assume that the list argument is modified, if that is possible.

Args:

  • list_: a list-like value.
  • x: value to append to list.

Returns:

  • same as list_, with an appended value.

Example:

l.append(x)
l = ag__.list_append(l, x)

Dispatch on list_:

  • default: to list_.append.
  • tf.Tensor: to tf.raw_ops.tensor_list_push_back.
  • tf.TensorArray: to tf.TensorArray.write.
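The reason the generated code reassigns the result (l = ag__.list_append(l, x)) becomes clear from a sketch. A Python list is modified in place, but tf.TensorArray.write returns a new TensorArray. The version below is simplified and hypothetical: it assumes TensorFlow is optional and omits the tensor-list path.

```python
try:
    import tensorflow as tf  # optional; only needed for the TensorArray path
except ImportError:
    tf = None


def list_append(list_, x):
    """Simplified sketch of list_append (Python lists and TensorArray only)."""
    if tf is not None and isinstance(list_, tf.TensorArray):
        # TensorArray.write returns a new TensorArray; it must be kept.
        return list_.write(list_.size(), x)
    # Default path: append in place, and also return the same list.
    list_.append(x)
    return list_
```

Reassigning the return value is correct in both cases: for a Python list it is the same object, and for a TensorArray it is the updated value.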
list_pop

Source

List pop operation: l.pop(i). Callers should assume that the list argument is modified, if that is possible.

Args:

  • list_: a list-like value.
  • i: optional index to remove from.
  • opts: optional, implementation-specific arguments.

Returns:

  • new_list: same as list_, with the value removed
  • x: the value that was removed

Example:

x = l.pop()
l, x = ag__.list_pop(l)

Dispatch on list_:

  • default: to list_.pop.
  • tf.Tensor: to tf.raw_ops.tensor_list_pop_back.
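On the default path, list_pop is a thin wrapper around list.pop that returns both the list and the removed value. The signature in this sketch is simplified: i defaults to None and the implementation-specific opts argument is omitted.

```python
def list_pop(list_, i=None):
    """Sketch of the default (pure Python) path of list_pop."""
    # Mirror list.pop(): without an index, remove the last element.
    x = list_.pop() if i is None else list_.pop(i)
    return list_, x
```

Returning the list as well as the value lets the same generated code (l, x = ag__.list_pop(l)) work when the list is an immutable tensor value rather than a Python list.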
list_stack
ListPopOpts
ListStackOpts
new_list

Exceptions

assert_stmt

Boolean

and_
eq
not_
not_eq
or_

Python built-ins

float_
int_
len_
print_
range_

Slicing

get_item
GetItemOpts
set_item

Variables

ld
ldu
Undefined
UndefinedReturnValue