Expressing Stateful Ops and Mutability in TOSA
This document describes a collection of utility operations to express statefulness around the TOSA specification. The goals of the operators are to express statefulness, scoping and mutability, while aligning with existing frontend and backend infrastructure supporting statefulness.
These operators do not impact the TOSA specification. They are designed primarily to work within the capabilities and constraints of the TOSA MLIR implementation or a similar SSA IR, and are intended to be solely compiler-visible utility operations.
Memory Model
A tensor is an abstract and immutable value comprising metadata describing its shape and element datatype, and the data.
The course of execution of the neural network consumes tensors, generating new tensors as the output of individual operations.
A symbolic reference (or symref) of a tensor is a reference to the object holding a tensor. It interacts with tensor values through read/write semantics. A read returns the tensor value most recently written by a write.
Each tensor is allocated discretely in non-aliasing memory and thus a tensor can have only a single symref to it.
The proposal assumes that two sets of memory regions are implemented for read/write content:
- A procedure-scoped stack-like memory
- A process-scoped heap-like memory
Operators
tosa.variable
This operator defines a symbolic reference to a tensor allocated in a compiler-allocated persistent heap-like memory, or an analogous solution. The operator definition specifies the name, shape and datatype.
The syntax of this operator is:
tosa.variable @sym_ref_name : <tensortype>
For example:
tosa.variable @myvar : tensor<1xf32>
Purpose:
- Compile-time definition of liveness within the current subgraph scope.
- Defines the size for allocation in persistent memory at compile time.
Notes:
- Scope is implicit from location of definition - within enclosing subgraph.
- Does not emit any functional code, just allocates memory and defines addressing and linkage for subsequent references.
- The allocation region is expected to be initialized to 0.
- tensortype defines the tensor metadata: shape and element dtype. The shape is not required to be static; it may be inferred as part of a dynamic shape resolution infrastructure, e.g. Extending TOSA to Support Symbolic Shapes (post #14 by sjarus).
tosa.variable.read
This operator defines that a read will be performed from the location defined by the symbolic reference.
The operator does not perform the actual read operation. It simply defines that the SSA value-carrying literal will access this tensor from its location in persistent memory.
In comparison, regular SSA value-carrying literals are allocated in (stack-like) procedure-scoped memory.
The syntax for this operator is:
%ssa_id = tosa.variable.read @sym_ref_name : <tensortype>
For example:
%1 = tosa.variable.read @ctr : tensor<13x21x3xi32>
Purpose:
- Defines that the SSA literal accesses persistent state, from the location whose addressing and linkage is previously defined by the tosa.variable op.
- The tensor type information enables type and bounds validation.
Notes:
- Semantics of the read are dictated by the compiler. For example, it may perform the read by copying the latest value of the tensor object from memory, or it may present something similar to an rvalue-reference.
tosa.variable.write
This operator defines that a write will be performed to the location defined by the symbolic reference.
The operator does not perform the actual write operation. It simply defines that the SSA value-carrying literal will perform a write to this tensor in its location in persistent memory.
The syntax for this operator is:
tosa.variable.write @sym_ref_name, %ssa_id : <tensortype>
For example:
tosa.variable.write @cell_state, %22 : tensor<24xf32>
Purpose:
- Defines that the SSA literal is written to persistent state, at the location whose addressing and linkage is previously defined by the tosa.variable op.
- The tensor type information enables type and bounds validation.
Notes:
- As with the case of tosa.variable.read, the semantics of the operation are driven by the compiler.
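Taken together, the three operators can express a read-modify-write update of persistent state. The following is a minimal sketch written in the syntax proposed above, not a verified upstream example: the names @running_sum, @accumulate and %delta are hypothetical, and the exact assembly format accepted by a given MLIR version may differ.

```mlir
// Illustrative sketch only. @running_sum, @accumulate and %delta are
// hypothetical names; the syntax follows the forms proposed in this document.
module {
  // Declares persistent state in the enclosing scope. Emits no functional
  // code; the region is expected to be zero-initialized.
  tosa.variable @running_sum : tensor<1xi32>

  func.func @accumulate(%delta: tensor<1xi32>) -> tensor<1xi32> {
    // %old denotes the value most recently written to @running_sum.
    %old = tosa.variable.read @running_sum : tensor<1xi32>

    // Ordinary TOSA computation on immutable SSA values.
    %new = tosa.add %old, %delta : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

    // Marks %new as the value to be stored back to persistent state.
    tosa.variable.write @running_sum, %new : tensor<1xi32>

    // A read placed after the write is expected to observe %new.
    %cur = tosa.variable.read @running_sum : tensor<1xi32>
    return %cur : tensor<1xi32>
  }
}
```

Under the zero-initialization note for tosa.variable, the first invocation would be expected to return %delta itself, and each later invocation the accumulated sum. How the reads and the write are actually realized (copy, reference, or otherwise) remains a compiler decision, as noted for both operators.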
Open Questions
- Can symrefs only refer to tensors in persistent memory? Can they refer to procedure-scoped (I/O and intermediate) tensors or immutable weights? Perhaps a lower-level IR can choose to do so, but not at the TOSA level of abstraction?
- Can non-aliasing behavior be restricted to program-side behavior, and not runtime behavior? Some RNN implementations involve the persistent memory being periodically cleared by the runtime when the process is quiet. In this case, the non-aliasing guarantee of read/write semantics does not hold.