tf.cond is used to define an if-statement in a TensorFlow operation graph. It has the following signature:
cond(
pred,
true_fn=None,
false_fn=None)
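An important point is that tf.cond calls both true_fn and false_fn when the graph is built, and conditional execution applies only to the operations created inside those functions. Any operation created outside them that a branch merely references is run regardless of pred. This is mentioned in the TensorFlow API documentation, but it is easy to miss on a first read.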
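The graph-construction side of this can be observed directly. The following is a minimal sketch, assuming TensorFlow 1.14 or later (or 2.x) so that the tensorflow.compat.v1 module is available; the calls list and the two named branch functions are hypothetical helpers used only to record which Python callables tf.cond invokes.

```python
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x

tf.disable_v2_behavior()  # graph mode, as in TF 1.x

calls = []  # hypothetical helper: records which branch functions get called


def true_fn():
    calls.append('true_fn')
    return tf.constant(10)


def false_fn():
    calls.append('false_fn')
    return tf.constant(20)


x = tf.constant(1)
res = tf.cond(x > 1, true_fn, false_fn)

# Both functions have already been called here, before any session exists.
print(sorted(calls))  # ['false_fn', 'true_fn']

with tf.Session() as sess:
    value = sess.run(res)  # x > 1 is False, so the false branch is selected

print(value)  # 20
```

Both callables are invoked while the graph is being built, yet the session returns only the value of the selected branch.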
For example, in the following code the value of z is 6 (the sum of x and y) even though x > 1 is False. Because z = tf.add(x, y) is defined outside the branches, it is computed whichever branch is selected, and z ends up as 6.
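import tensorflow as tf
x = tf.constant(1)
y = tf.constant(5)
# z is created as a variable here but rebound to the add op on the next line
z = tf.get_variable('z', shape=[1])
# defined outside both branches of tf.cond
z = tf.add(x, y)
res = tf.cond(x > 1, lambda: tf.multiply(x, z), lambda: y)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# print() with parentheses works on Python 2 and 3
print(sess.run([res, z]))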
------
output: [5, 6]
------
To execute only one of the branches based on a condition, the operations must be created inside the function passed as true_fn or false_fn, not outside it. This is described in detail in this Stack Overflow post.
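The difference between an operation created outside a branch function and one created inside it can be made visible with a side effect. The following is a sketch only, assuming TensorFlow 1.14 or later (or 2.x) with the tensorflow.compat.v1 module; the two counter variables and the true_branch function are hypothetical names introduced for illustration.

```python
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x

tf.disable_v2_behavior()  # graph mode, as in TF 1.x

x = tf.constant(1)
y = tf.constant(5)

# Hypothetical counters, used only to observe which ops actually run.
outside = tf.Variable(0, name='outside')
inside = tf.Variable(0, name='inside')

# Created OUTSIDE the branch: runs whenever res_outside is evaluated.
bump_outside = tf.assign_add(outside, 1)
res_outside = tf.cond(x > 1, lambda: tf.multiply(x, bump_outside), lambda: y)


def true_branch():
    # Created INSIDE the branch: runs only when the predicate is True.
    bump_inside = tf.assign_add(inside, 1)
    return tf.multiply(x, bump_inside)


res_inside = tf.cond(x > 1, true_branch, lambda: y)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    results = sess.run([res_outside, res_inside])
    counters = sess.run([outside, inside])

print(results)   # both conditionals return y, i.e. 5
print(counters)  # outside was incremented to 1, inside is still 0
```

Since x > 1 is False, both conditionals return y, but only the counter whose increment was created outside the branch function changes.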