Expression
abstract class Expression<Value>(x:Value!?, flagConstant:Boolean) < Delay
Abstract interface for evaluating and differentiating expressions.
- Value Result type.
- x Evaluated value of the expression (optional).
- flagConstant Is this a constant expression?
Delayed expressions (alternatively: lazy expressions, compute graphs, expression templates) encode mathematical expressions that can be evaluated, differentiated, and moved (using Markov kernels). They are assembled using mathematical operators and functions much like ordinary expressions, but one or more operands or arguments are Random objects. Where an ordinary expression is evaluated immediately into a result, a delayed expression evaluates to further Expression objects.
Simple delayed expressions are trees of subexpressions with Random or Boxed objects at the leaves. In general, however, a delayed expression can be a directed acyclic graph, as subexpressions may be reused during assembly.
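As a rough illustration of the idea (not of Birch itself), the Python sketch below uses operator overloading so that arithmetic on a Random returns another expression object instead of a number. The classes Expr, Random, Boxed, Add and Mul and the method evaluate() are hypothetical stand-ins. Evaluation here is deliberately naive and unmemoized, so the shared node y is computed twice.

```python
class Expr:
    """Hypothetical base class for delayed expressions (not Birch's API)."""

    def __add__(self, other):
        return Add(self, lift(other))

    def __radd__(self, other):
        return Add(lift(other), self)

    def __mul__(self, other):
        return Mul(self, lift(other))

    def __rmul__(self, other):
        return Mul(lift(other), self)


class Random(Expr):
    """Leaf whose value may still change (an argument)."""

    def __init__(self, x):
        self.x = x

    def evaluate(self):
        return self.x


class Boxed(Expr):
    """Leaf wrapping an ordinary constant value."""

    def __init__(self, x):
        self.x = x

    def evaluate(self):
        return self.x


class Add(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def evaluate(self):
        return self.left.evaluate() + self.right.evaluate()


class Mul(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def evaluate(self):
        return self.left.evaluate() * self.right.evaluate()


def lift(v):
    """Wrap plain numbers so they can appear as leaves."""
    return v if isinstance(v, Expr) else Boxed(v)


x = Random(2.0)
y = x * 3.0       # delayed: y is an expression object, not a number
z = y + y         # y is reused, so the graph is a DAG rather than a tree
print(type(z).__name__, z.evaluate())  # Add 12.0
```

Because z.left and z.right are the same object, any later change to x is seen through both uses of y.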
Simple use
Note: Call value() on an Expression to evaluate it.
The simplest use case of a delayed expression is to assemble it and then evaluate it by calling value().
Once value() is called on an Expression, it and all subexpressions that constitute it are rendered constant. This particularly affects any Random objects in the expression, the value of which can no longer be altered.
Evaluations are memoized at checkpoints. Each Expression object that occurs in the delayed expression forms such a checkpoint. Further calls to value() are optimized to retrieve these memoized evaluations rather than re-evaluate the whole expression.
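A minimal Python sketch of these two behaviors, assuming a simplified node that memoizes its result and a Random whose set() refuses changes once it is constant. The classes and the evaluations counter are hypothetical and only mirror the description above.

```python
class Expr:
    """Hypothetical delayed-expression node with a memoized value()."""

    def __init__(self, *children):
        self.children = children
        self.x = None          # memoized result (None = not yet evaluated)
        self.constant = False
        self.evaluations = 0   # how many times compute() actually ran

    def value(self):
        """Evaluate (reusing the memo if present) and render constant."""
        if self.x is None:
            self.x = self.compute([c.value() for c in self.children])
            self.evaluations += 1
        self.constant = True
        return self.x

    def compute(self, args):
        raise NotImplementedError


class Random(Expr):
    def __init__(self, x):
        super().__init__()
        self.x = x             # a leaf already holds its value

    def set(self, x):
        if self.constant:
            raise ValueError("cannot alter a Random that has been rendered constant")
        self.x = x


class Square(Expr):
    def compute(self, args):
        return args[0] ** 2


class Add(Expr):
    def compute(self, args):
        return args[0] + args[1]


a = Random(3.0)
s = Square(a)
e = Add(s, s)        # s is shared; its second use hits the memo
print(e.value())     # 18.0
print(e.value())     # 18.0, retrieved from the memo at e
try:
    a.set(1.0)
except ValueError as err:
    print(err)
```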
Advanced use
More elaborate use cases include computing gradients and applying Markov kernels. Call eval() to evaluate the expression in the same way as value(), but without rendering it constant. Any Random objects in the expression that have not been rendered constant by a previous call to value() are considered arguments of the expression.
Use grad() to compute the gradient of an expression with respect to its arguments. The gradients are accumulated in the Random arguments and can be retrieved from them.
Use set() to update the value of Random arguments, then call move() on the whole expression to re-evaluate it.
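The eval(), grad(), set() and move() cycle can be sketched in Python for the scalar expression f = a*b + b. Everything here is hypothetical: the classes, and the simplification that move() takes no arguments (the Birch move() listed below takes a vector). This naive grad() also visits a shared node once per use.

```python
class Random:
    """Hypothetical argument: holds a value and an accumulated gradient."""

    def __init__(self, x):
        self.x, self.g = x, 0.0

    def eval(self):
        return self.x

    def move(self):
        return self.x

    def grad(self, d):
        self.g += d            # gradients accumulate in the argument

    def set(self, x):
        self.x = x


class Mul:
    def __init__(self, l, r):
        self.l, self.r, self.x = l, r, None

    def eval(self):
        if self.x is None:     # reuse the memo if present
            self.x = self.l.eval() * self.r.eval()
        return self.x

    def move(self):
        self.x = self.l.move() * self.r.move()   # ignore memos
        return self.x

    def grad(self, d):
        # d(l*r)/dl = r and d(l*r)/dr = l, using the forward-pass values
        self.l.grad(d * self.r.eval())
        self.r.grad(d * self.l.eval())


class Add:
    def __init__(self, l, r):
        self.l, self.r, self.x = l, r, None

    def eval(self):
        if self.x is None:
            self.x = self.l.eval() + self.r.eval()
        return self.x

    def move(self):
        self.x = self.l.move() + self.r.move()
        return self.x

    def grad(self, d):
        self.l.grad(d)         # both partial derivatives of a sum are 1
        self.r.grad(d)


a, b = Random(2.0), Random(3.0)
f = Add(Mul(a, b), b)          # f = a*b + b
print(f.eval())                # 9.0
f.grad(1.0)
print(a.g, b.g)                # 3.0 3.0  (df/da = b, df/db = a + 1)
a.set(4.0)
print(f.eval())                # 9.0: the memo is now stale
print(f.move())                # 15.0 = 4*3 + 3
```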
Use value(), not eval(), unless you are taking responsibility for correctness (e.g. moving arguments in a manner invariant to some target distribution, using a Markov kernel). Otherwise, program behavior may lack self-consistency. Consider, for example:
```birch
if x.value() >= 0.0 {
  doThis();
} else {
  doThat();
}
```
This is correct usage. Using x.eval() instead of x.value() here would allow the value of x to be changed later to a negative value. The program would then lack self-consistency, as it executed doThis() rather than doThat() based on a value that no longer holds.
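The hazard can be reproduced in a small Python sketch, assuming a hypothetical Random whose value() freezes it while eval() does not. The branch function stands in for the if statement above.

```python
class Random:
    """Hypothetical random variable: value() freezes it, eval() does not."""

    def __init__(self, x):
        self.x = x
        self.constant = False

    def eval(self):
        return self.x

    def value(self):
        self.constant = True
        return self.x

    def set(self, x):
        if self.constant:
            raise ValueError("value is fixed: the program already branched on it")
        self.x = x


def branch(x, use_value):
    v = x.value() if use_value else x.eval()
    return "doThis" if v >= 0.0 else "doThat"


x = Random(1.5)
taken = branch(x, use_value=False)   # "doThis"
x.set(-2.0)                          # allowed, because x was only eval()'d
print(taken, x.x)                    # took doThis, yet x is now negative

y = Random(1.5)
branch(y, use_value=True)
try:
    y.set(-2.0)                      # rejected: y was rendered constant
except ValueError as err:
    print("rejected:", err)
```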
Checkpoints and memoization
eval() memoizes only at checkpoints defined by Expression objects. This reduces memory use. Internally, grad() benefits from re-evaluating expressions between checkpoints to memoize intermediate results. It does so by calling peek(). These memoized intermediate results are cleared again as grad() progresses.
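The memory trade-off can be sketched in Python with two hypothetical classes: Form, an unmemoized operation node inside a checkpoint, and Checkpoint, which stands in for an Expression object. Only the checkpoint stores a result on eval(); peek() additionally memoizes every intermediate form, and clear() drops those memos again (all at once here, for brevity, rather than progressively).

```python
import operator


class Form:
    """Hypothetical unmemoized operation node inside a checkpoint."""

    def __init__(self, op, *args):
        self.op, self.args = op, args
        self.memo = None       # set only by peek()

    def _vals(self, method):
        return [getattr(a, method)() if isinstance(a, Form) else a for a in self.args]

    def compute(self):
        """Evaluate without storing any intermediate result."""
        return self.op(*self._vals("compute"))

    def peek(self):
        """Evaluate, memoizing this node and every intermediate form."""
        if self.memo is None:
            self.memo = self.op(*self._vals("peek"))
        return self.memo

    def clear(self):
        """Drop the intermediate memos again."""
        self.memo = None
        for a in self.args:
            if isinstance(a, Form):
                a.clear()


class Checkpoint:
    """Hypothetical Expression object: the only place eval() memoizes."""

    def __init__(self, form):
        self.form, self.x = form, None

    def eval(self):
        if self.x is None:
            self.x = self.form.compute()
        return self.x


def count_memos(form):
    own = 0 if form.memo is None else 1
    return own + sum(count_memos(a) for a in form.args if isinstance(a, Form))


outer = Form(operator.add, Form(operator.mul, 2.0, 3.0), 4.0)  # 2*3 + 4
c = Checkpoint(outer)
print(c.eval(), count_memos(outer))     # 10.0 0  (only the checkpoint stores)
print(outer.peek(), count_memos(outer)) # 10.0 2  (intermediates memoized)
outer.clear()
print(count_memos(outer))               # 0
```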
Member Variables
Name | Description |
---|---|
x:Value!? | Memoized result. |
linkCount:Integer | Count of number of parents, set by trace(). |
visitCount:Integer | Counter used to obtain pre- and post-order traversals of the expression graph. |
flagConstant:Boolean | Is this a constant expression? |
Member Functions
Name | Description |
---|---|
isRandom | Is this a Random expression? |
isConstant | Is this a constant expression? |
hasValue | Does this have a value? |
hasGradient | Does this have a gradient? |
rows | Number of rows in result. |
columns | Number of columns in result. |
length | Length of result. |
size | Size of result. |
value | Evaluate and render constant. |
eval | Evaluate. |
move | Re-evaluate, ignoring memos. |
args | Vectorize arguments and gradients. |
grad | Evaluate gradient with respect to arguments. |
peek | Evaluate. |
trace | Trace an expression before calling grad() or move(). |
constant | Render the entire expression constant. |
Member Function Details
args
final function args() -> (Real[_], Real[_])
Vectorize arguments and gradients.
Returns The vectorized arguments and gradients.
columns
final function columns() -> Integer
Number of columns in result.
constant
final override function constant()
Render the entire expression constant.
eval
final function eval() -> Value!
Evaluate.
Returns The result.
grad
final function grad<Gradient>(g:Gradient)
Evaluate gradient with respect to arguments. Clears memos at fine-grain.
- g Upstream gradient.
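The expression is treated as a function, and its arguments are defined as those Random objects in the expression that are not constant.
If the expression encodes
$$f_n(\cdots f_2(f_1(x_0)) \cdots),$$
and this particular object encodes one of those functions, $x_i = f_i(x_{i-1})$, then the upstream gradient g is
$$g = \frac{\partial (f_n \circ \cdots \circ f_{i+1})}{\partial x_i}(x_i).$$
grad() then computes, by the chain rule,
$$\frac{\partial (f_n \circ \cdots \circ f_i)}{\partial x_{i-1}}(x_{i-1}) = g \, \frac{\partial f_i}{\partial x_{i-1}}(x_{i-1}),$$
and passes the result to the next step in the chain, which encodes $f_{i-1}$. The argument that encodes $x_0$ is a Random object, and it keeps the final result.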
Reverse-mode automatic differentiation is used. The previous call to eval() constitutes the forward pass, and the call to grad() the backward pass.
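For a scalar chain, the two passes can be sketched in Python. The list-of-pairs representation of the steps f_i and their derivatives is a hypothetical stand-in for Expression objects: the forward pass records each x_{i-1}, and the backward pass multiplies the upstream gradient by the derivative of f_i at x_{i-1}, from f_n down to f_1. The result is checked against a central finite difference.

```python
import math

# Each step encodes x_i = f_i(x_{i-1}) together with its derivative f_i'.
steps = [
    (lambda x: x * x, lambda x: 2.0 * x),   # f_1
    (math.sin, math.cos),                   # f_2
    (lambda x: 3.0 * x, lambda x: 3.0),     # f_3
]


def forward(x0):
    """Forward pass (the role of eval()): record the input x_{i-1} of each f_i."""
    inputs, x = [], x0
    for f, _ in steps:
        inputs.append(x)
        x = f(x)
    return x, inputs


def backward(inputs, g=1.0):
    """Backward pass (the role of grad()): turn the upstream gradient for x_i
    into the gradient for x_{i-1} by multiplying by f_i'(x_{i-1})."""
    for (_, df), x_prev in zip(reversed(steps), reversed(inputs)):
        g = g * df(x_prev)
    return g                 # gradient with respect to x_0


x0 = 0.7
y, inputs = forward(x0)
dy = backward(inputs)        # analytically 3 * cos(x0**2) * 2 * x0
h = 1e-6
fd = (forward(x0 + h)[0] - forward(x0 - h)[0]) / (2 * h)
print(dy, fd)
```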
Because expressions are, in general, directed acyclic graphs, a counting mechanism is used to accumulate upstream gradients into any shared subexpressions before visiting them. This ensures that each subexpression is visited only once, not as many times as it is used. Mathematically, this is equivalent to factorizing out the subexpression as a common factor in the application of the chain rule. It turns out to be particularly important when expressions include posterior parameters after multiple Bayesian updates applied by automatic conditioning. Such expressions can have many common subexpressions, and the counting mechanism results in automatic differentiation of complexity O(N) in the number of updates, as opposed to O(N^2) otherwise.
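The counting mechanism can be sketched in Python on a deliberately contrived DAG in which every node is used twice by its parent (y_k = y_{k-1} + y_{k-1}). The Node class, trace(), grad_counted() and grad_naive() are hypothetical simplifications: trace() records each node's number of parents, and grad_counted() accumulates upstream gradients until all parents have reported before visiting a node's children. In this particular chain the naive traversal grows exponentially; the O(N^2) figure above refers to the posterior-update case, which this sketch does not model.

```python
class Node:
    """Hypothetical DAG node computing the sum of its children
    (a leaf has no children and plays the role of a Random argument)."""

    def __init__(self, *children):
        self.children = children
        self.link_count = 0   # number of parents, set by trace()
        self.pending = 0.0    # upstream gradient accumulated so far
        self.arrived = 0      # parents that have reported so far
        self.visits = 0       # times this node's backward step ran
        self.g = 0.0          # final gradient, kept by leaves


def trace(node):
    """Count the parents of every node, recursing only on first encounter."""
    for c in node.children:
        c.link_count += 1
        if c.link_count == 1:
            trace(c)


def grad_counted(node, d):
    """Accumulate d; run the backward step once all parents have reported."""
    node.pending += d
    node.arrived += 1
    if node.arrived < node.link_count:
        return
    node.visits += 1
    if not node.children:
        node.g += node.pending
    for c in node.children:
        grad_counted(c, node.pending)  # each summand has partial derivative 1
    node.pending, node.arrived = 0.0, 0


def grad_naive(node, d):
    """Visit a node once per use, as a plain tree traversal would."""
    node.visits += 1
    if not node.children:
        node.g += d
    for c in node.children:
        grad_naive(c, d)


def chain(n):
    """y_k = y_{k-1} + y_{k-1}: every node is used twice by its parent."""
    nodes = [Node()]
    for _ in range(n):
        nodes.append(Node(nodes[-1], nodes[-1]))
    return nodes


counted = chain(10)
trace(counted[-1])
grad_counted(counted[-1], 1.0)
naive = chain(10)
grad_naive(naive[-1], 1.0)
print(counted[0].g, sum(n.visits for n in counted))  # 1024.0 11
print(naive[0].g, sum(n.visits for n in naive))      # 1024.0 2047
```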
hasGradient
final function hasGradient() -> Boolean
Does this have a gradient?
hasValue
final function hasValue() -> Boolean
Does this have a value?
isConstant
final function isConstant() -> Boolean
Is this a constant expression?
isRandom
override function isRandom() -> Boolean
Is this a Random expression?
length
final function length() -> Integer
Length of result. This is synonymous with rows().
move
final function move(x:Real[_]) -> Value!
Re-evaluate, ignoring memos. Memoizes at coarse-grain (i.e. Expression objects, not forms).
- x Vectorized arguments.
Returns The result.
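How args() and move() fit together can be sketched in Python with a hypothetical expression f(a, b) = -(a^2 + (b - 1)^2)/2 over two Random arguments. The loop takes plain gradient-ascent steps: read the vectorized arguments and gradients, compute a new vector, and pass it to move(). This is a swapped-in illustration, not a Markov kernel (there is no noise or acceptance step), and resetting the gradients before each grad() call is an assumption of the sketch.

```python
class Random:
    """Hypothetical argument holding a value and an accumulated gradient."""

    def __init__(self, x):
        self.x, self.g = x, 0.0


class QuadraticExpression:
    """Hypothetical expression f(a, b) = -(a**2 + (b - 1)**2) / 2
    with args()/move() analogues; not Birch's API."""

    def __init__(self, a, b):
        self.arguments = [a, b]
        self.x = None                      # memoized result

    def _compute(self):
        a, b = (r.x for r in self.arguments)
        return -(a * a + (b - 1.0) ** 2) / 2.0

    def eval(self):
        if self.x is None:
            self.x = self._compute()
        return self.x

    def grad(self, d=1.0):
        a, b = self.arguments
        a.g += d * (-a.x)
        b.g += d * (-(b.x - 1.0))

    def args(self):
        """Vectorize arguments and gradients."""
        return [r.x for r in self.arguments], [r.g for r in self.arguments]

    def move(self, x):
        """Write a new argument vector, then re-evaluate ignoring the memo."""
        for r, v in zip(self.arguments, x):
            r.x = v
        self.x = self._compute()
        return self.x


e = QuadraticExpression(Random(2.0), Random(-1.0))
e.eval()
step = 0.5
for _ in range(20):
    for r in e.arguments:
        r.g = 0.0                          # assumed: start each gradient afresh
    e.grad()
    x, g = e.args()
    e.move([xi + step * gi for xi, gi in zip(x, g)])
print(e.args()[0], e.x)                    # arguments approach [0.0, 1.0]
```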
peek
final function peek() -> Value!
Evaluate.
Returns The result.
rows
final function rows() -> Integer
Number of rows in result.
size
final function size() -> Integer
Size of result. This is equal to rows()*columns().
trace
final function trace()
Trace an expression before calling grad() or move(). This traces through the expression and, for each Expression object, updates the count of its number of parents. This is necessary to ensure correct and efficient execution of grad() and move(), as these counts ensure that each node in the graph is visited exactly once.
value
final function value() -> Value!
Evaluate and render constant.
Returns The result.