Statefulness Support for TOSA

Expressing Stateful Ops and Mutability in TOSA

This document describes a collection of utility operations to express statefulness around the TOSA specification. The goals of the operators are to express statefulness, scoping and mutability, while aligning with existing frontend and backend infrastructure supporting statefulness.

These operators are not TOSA spec impacting. They are designed primarily to work within the capabilities and constraints of the TOSA MLIR implementation or a similar SSA IR, and are intended to be solely compiler-visible utility operations.

Memory Model

A tensor is an abstract, immutable value comprising metadata (its shape and element datatype) and its data.

Executing the neural network consumes tensors and produces new tensors as the outputs of individual operations.

A symbolic reference (or symref) of a tensor is a reference to the object holding a tensor. It interacts with tensor values through read/write semantics: a read returns the tensor value most recently written by a write.

Each tensor is allocated discretely in non-aliasing memory, and thus a tensor can have only a single symref to it.
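As an illustration of these semantics, here is a minimal Python sketch. `Tensor` and `SymRef` are hypothetical names, not part of any TOSA implementation, and the sketch assumes every symref starts out holding an initial value. A write rebinds the object behind the symref to a new immutable tensor, and a read returns whichever value was written last.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Tensor:
    """Immutable tensor value: metadata (shape, element dtype) plus data."""
    shape: Tuple[int, ...]
    dtype: str
    data: Tuple[float, ...]


class SymRef:
    """Hypothetical symbolic reference to the single object holding a tensor."""

    def __init__(self, name: str, initial: Tensor) -> None:
        self.name = name
        self._latest = initial

    def write(self, value: Tensor) -> None:
        # Rebinds the object to a new value; previously written Tensor values
        # are never modified, since Tensor is frozen.
        self._latest = value

    def read(self) -> Tensor:
        # Returns the value most recently written.
        return self._latest
```

Because `Tensor` is frozen, a value obtained by an earlier read stays unchanged after later writes. That is the property that rules out aliasing at this level of abstraction.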

The proposal assumes that two sets of memory regions are implemented for read/write content:

  • A procedure-scoped stack-like memory
  • A process-scoped heap-like memory

Operators

tosa.variable

This operator defines a symbolic reference to a tensor allocated in a compiler-allocated persistent heap-like memory, or an analogous solution. The operator specifies the name, shape and datatype.

The syntax of this operator is:

tosa.variable @sym_ref_name : <tensortype>

For example:

tosa.variable @myvar : tensor<1xf32>

Purpose:

  • Compile-time definition of liveness within the current subgraph scope.
  • Defines the size for allocation in persistent memory at compile time.

Notes:

  • Scope is implicit from the location of the definition, i.e. within the enclosing subgraph.
  • Does not emit any functional code, just allocates memory and defines addressing and linkage for subsequent references.
  • The allocation region is expected to be initialized to 0.
  • tensortype defines the tensor metadata: shape and element dtype. The shape is not required to be static; it may be inferred as part of a dynamic shape resolution infrastructure, e.g. the one proposed in Extending TOSA to Support Symbolic Shapes.
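The compile-time role of tosa.variable can be sketched roughly as follows. `VariableDecl`, `allocate`, and the `ELEMENT_BYTES` table are all hypothetical. The sketch assumes `None` stands in for a `?` dimension and that dynamically shaped variables get their storage later, through shape resolution or runtime heap management. Each statically shaped declaration yields a zero-filled region whose size is known before execution.

```python
from dataclasses import dataclass
from math import prod
from typing import Dict, Iterable, Optional, Tuple

# Assumed element sizes in bytes for a few element types (illustrative only).
ELEMENT_BYTES = {"i8": 1, "i16": 2, "i32": 4, "f16": 2, "f32": 4}


@dataclass(frozen=True)
class VariableDecl:
    """Models `tosa.variable @name : tensor<...>`; None marks a '?' dimension."""
    name: str
    shape: Tuple[Optional[int], ...]
    dtype: str

    def static_size_bytes(self) -> Optional[int]:
        """Compile-time allocation size, or None if any dimension is dynamic."""
        if any(d is None for d in self.shape):
            return None
        return prod(self.shape) * ELEMENT_BYTES[self.dtype]


def allocate(decls: Iterable[VariableDecl]) -> Dict[str, bytearray]:
    """Reserve zero-filled persistent storage for each statically shaped variable.

    Dynamically shaped variables are skipped here; their storage is left to a
    later shape resolution step or to runtime heap management.
    """
    heap: Dict[str, bytearray] = {}
    seen = set()
    for decl in decls:
        if decl.name in seen:
            raise ValueError(f"duplicate tosa.variable @{decl.name}")
        seen.add(decl.name)
        size = decl.static_size_bytes()
        if size is not None:
            heap[decl.name] = bytearray(size)  # bytearray(n) is n zero bytes
    return heap
```

For example, `tosa.variable @myvar : tensor<1xf32>` corresponds to `VariableDecl("myvar", (1,), "f32")`, which would receive 4 zero bytes under the assumed size table.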

tosa.variable.read

This operator defines that a read will be performed from the location defined by the symbolic reference.

The operator does not perform the actual read operation. It simply defines that the SSA value-carrying literal will access this tensor from its location in persistent memory.

In comparison, regular SSA value-carrying literals are allocated in (stack-like) procedure-scoped memory.

The syntax for this operator is:

%ssa_id = tosa.variable.read @sym_ref_name : <tensortype>

For example:

%1 = tosa.variable.read @ctr : tensor<13x21x3xi32>

Purpose:

  • Defines that the SSA literal accesses persistent state, at the location whose addressing and linkage were previously defined by the tosa.variable op.
  • The tensor type information enables type and bounds validation.

Notes:

  • Semantics of the read are dictated by the compiler. For example, it may perform the read by copying the latest value of the tensor object from memory, or it may present something similar to an rvalue-reference.

tosa.variable.write

This operator defines that a write will be performed to the location defined by the symbolic reference.

The operator does not perform the actual write operation. It simply defines that the value carried by the SSA literal will be written to this tensor's location in persistent memory.

The syntax for this operator is:

tosa.variable.write @sym_ref_name, %ssa_id : <tensortype>

For example:

tosa.variable.write @cell_state, %22 : tensor<24xf32>

Purpose:

  • Defines a write of the SSA literal to persistent state, at the location whose addressing and linkage were previously defined by the tosa.variable op.
  • The tensor type information enables type and bounds validation.

Notes:

  • As with the case of tosa.variable.read, the semantics of the operation are driven by the compiler.
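Both read and write carry a tensor type so that the compiler can validate the type and bounds of each access against the tosa.variable declaration. Below is a rough sketch of such a check. `verify_access` and `shapes_compatible` are hypothetical names. The compatibility rule is also an assumption: ranks must match, and a dynamic extent (`None`, i.e. `?`) on either side matches any extent.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None marks a dynamic ('?') dimension


def shapes_compatible(a: Shape, b: Shape) -> bool:
    """Same rank, and each pair of extents is equal unless one is dynamic."""
    return len(a) == len(b) and all(
        x is None or y is None or x == y for x, y in zip(a, b)
    )


def verify_access(
    decls: Dict[str, Tuple[Shape, str]], op: str, name: str, shape: Shape, dtype: str
) -> None:
    """Check a tosa.variable.read/write type against its tosa.variable declaration."""
    if name not in decls:
        raise ValueError(f"{op}: @{name} has no tosa.variable declaration")
    decl_shape, decl_dtype = decls[name]
    if dtype != decl_dtype:
        raise TypeError(f"{op}: @{name} declared {decl_dtype}, used as {dtype}")
    if not shapes_compatible(decl_shape, shape):
        raise TypeError(f"{op}: @{name} declared {decl_shape}, used as {shape}")
```

With this rule, `tosa.variable.write @cell_state, %22 : tensor<24xf32>` passes against a `tensor<24xf32>` declaration, and a write typed `tensor<24xi32>` is rejected.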

Open Questions

  1. Can symrefs only refer to tensors in persistent memory? Can they refer to procedure-scoped (I/O and intermediate) tensors or to immutable weights? Perhaps a lower-level IR can choose to do so, but not at the TOSA level of abstraction?

  2. Can the non-aliasing guarantee be restricted to program-side behavior rather than runtime behavior? Some RNN implementations involve the persistent memory being periodically cleared by the runtime when the process is quiet. In this case, the non-aliasing guarantee of read/write semantics does not hold.

@stellaraccident and @rsuderman, please take a look when possible. I’d be happy to address any feedback.

Hi Suraj, a few comments.

I think this is going to be one of those things that isn’t going to age well. Also above, it says that the initialization value can be provided optionally, but that seems to contradict this (and there seems to be no syntax for the optional initialization). What were you thinking for initialization (in IREE, we support both an explicit initialization value or a reference to an initializer function).

We also support dynamically shaped variables and if not too much of a burden, it would be nice if these ops were dynamic shape tolerant from the get-go. As an example of such an initializer:

tosa.variable @foo : tensor<?x4xf32> = dense<[1]> : tensor<1x4xf32>

Nit: I would probably put the symbol ref first in the asm.

This and the similar statement for read sound funny to me, and when discussing with someone else, we had trouble working out what the actual memory model is. This stuff is a bit fiddly because, at this level:

  • tensor is an abstract, immutable value, containing metadata (shape and dtype) and contents
  • a variable denotes a “reference” to a tensor. Because a tensor is immutable, the variable holds a point-in-time snapshot of a tensor's parameters (shape/dtype/contents).
  • The semantics of read/write are such that a read will yield a tensor (shape/dtype/contents) equal to the value previously written via write (or the initializer, if no write has yet been performed).

There are various ways to lower from tensors to something more concrete. A way to describe this in terms of the physical semantics you reference would be in the form of a valid implementation which:

  • Implements every tensor in the program as a discrete, non-aliased region of memory for metadata+contents.
  • Variable read/write have load/store semantics, causing a copy of the operand (both metadata and contents) to be transferred in/out of the memory backing the variable.

More complicated lowerings that alias, when safe to do so, are possible (and a reality in any production compiler).
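The simple, valid implementation described above (every tensor in its own non-aliased region, with read/write as load/store) could look roughly like this sketch. `Buffer` and `VariableStore` are hypothetical names. Both directions copy the metadata and the contents, so a later mutation of either the source buffer or a loaded buffer never reaches the variable's backing memory.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class Buffer:
    """A discrete, non-aliased memory region holding one tensor's metadata and contents."""
    shape: Tuple[int, ...]
    dtype: str
    contents: bytearray


class VariableStore:
    """Backing memory for variables; read/write behave like load/store."""

    def __init__(self) -> None:
        self._slots: Dict[str, Buffer] = {}

    def store(self, name: str, value: Buffer) -> None:
        # tosa.variable.write: copy metadata and contents into the variable.
        self._slots[name] = Buffer(value.shape, value.dtype, bytearray(value.contents))

    def load(self, name: str) -> Buffer:
        # tosa.variable.read: copy metadata and contents out into a fresh region.
        slot = self._slots[name]
        return Buffer(slot.shape, slot.dtype, bytearray(slot.contents))
```

Lowerings that alias would replace these copies with shared storage wherever the compiler can prove the sharing is unobservable.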

I think we were (possibly?) misreading the physical description as implying that the variable contained a pointer to a memory region. That would involve buffer/pointer semantics at this level, which would further imply aliasing, and aliasing is forbidden at this level. Looking at the op definitions, this doesn’t seem to be what is being defined, but the English description left us with a question.

Yes, this is a vestige of going back and forth on this and not proofreading the final form. The genesis here is that an optional initial value was originally considered, but the sharp edge is how runtimes (e.g. TFL or TFLu) handle a region that is under the control of both the runtime and the executing network. Adding @Eric for further input.
AI: fix language to remove contradictions.

Yeah the proposal doesn’t object to shape dynamism here - that would be expected to be supported by compiler/runtime dynamic heap management.
The overall semantic here is that certain use cases may process a fully shape resolved form based on the compiler/runtime support, and other more advanced stacks may carry dynamic shapes.
AI: Add language to state shape need not be compile time defined.

Sure, sounds reasonable.
AI: swap symbol ref and SSA_ID

No, I think this is a matter of simply unclear language and a lack of precision regarding the memory model. This proposal goes beyond just an op set into an area that at least implicitly requires some abstract memory model.

The goal here is to have something that makes the widest possible set of implementation approaches and memory models possible, from something that is the most feasible on a very constrained embedded setup, to a ‘full fat’ platform, using potentially multiple compiler stacks that define TOSA - from the MLIR infrastructure to a simpler use-case focused stack.

I’ll take another pass at this for clarity, but the language here overloads terms in a manner that suggests specific implementation rules that it doesn’t demand.

Yeah, I’m not sure I’m tracking exactly. I think it is pretty straight-forward to say that variables can have an initial value and leave that to conversions in/out to adapt as needed. If you are picking an initial value (0 with a statically defined shape), it is strictly more flexible to let that initial value be customized – and it solves the problem of metadata that is more generic than the value.

I’ve updated the original post for more clarity. A few further notes:

  • I tried to describe tensors and symrefs (avoiding the term “variable” due to overloading concerns) using the analogy of rvalues and lvalues, but decided against that level of detail.
  • I’ve more explicitly noted dynamic shapes, referencing your proposal in the process.

Couldn’t this be finessed by the compiler? As with C++03 copy semantics versus newer standards, it could optimize using some mechanism analogous to rvalue references. The original text described this, but I removed it in favor of ‘it’s up to the compiler’. Hopefully the newer language offers better latitude for different approaches?
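As one example of how a compiler might do this, the copy on a read can be elided when no write to the same symref can be observed through the read's result. The sketch below applies a conservative rule to straight-line code: a read may alias the variable's storage only if no write to that symbol occurs after the read, up to and including the last use of the read's result. The `(kind, symbol, result, operands)` tuple encoding and the `"read"`/`"write"` kind strings are hypothetical stand-ins for tosa.variable.read/write, and real control flow would need a proper dataflow analysis.

```python
from typing import List, Optional, Set, Tuple

# Hypothetical straight-line op encoding: (kind, symbol, result, operands).
Op = Tuple[str, Optional[str], Optional[str], Tuple[str, ...]]


def reads_safe_to_alias(ops: List[Op]) -> Set[int]:
    """Indices of 'read' ops whose copy may be elided (rvalue-reference-like).

    Conservative: any write to the same symbol after the read, up to and
    including the last use of the read's result, forces a copy.
    """
    safe: Set[int] = set()
    for i, (kind, sym, result, _) in enumerate(ops):
        if kind != "read":
            continue
        last_use = i
        for j in range(i + 1, len(ops)):
            if result in ops[j][3]:
                last_use = j
        clobbered = any(
            ops[k][0] == "write" and ops[k][1] == sym
            for k in range(i + 1, last_use + 1)
        )
        if not clobbered:
            safe.add(i)
    return safe
```

In a sequence where `%0 = read @acc` is consumed by an add before `@acc` is written, the read can alias. A read whose result is still live when `@acc` is overwritten must copy.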

Thanks.

Personally, I would keep these at module-level – certainly to start. Semantics can be extended over time but it is hard to pull them back.

I think we can leave this undefined. A runtime that has an external function (say resetState()) can be defined to do whatever it wants (typically on whatever the lowered form ends up being). The aliasing restriction only applies to interactions with code defined in the module.

Back to the question of initialization, I believe you will very quickly need to define forms that allow setting the initial value (or provide a more elaborate mechanism for doing module-level initialization). What you have is narrowly sufficient for some of the existing stateful ops but is insufficient for anything broader (e.g. training variables). It is also a pretty normal thing in these levels of IR to be able to have explicit initialization.

From an IR design standpoint, accepting a DenseElementsAttr initial_value attribute gives a nice way to make explicit the zero initialization case without requiring words to describe this state (i.e. a dense<0> is a strong/explicit signal of what the semantic is, and it extends to further cases).
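The sketch below shows how an optional initial value could resolve against a declared, possibly dynamic, shape. `DenseInit` is a hypothetical stand-in for a DenseElementsAttr, not the MLIR class, and `resolve_initial_value` is likewise illustrative. It assumes a single value acts as a splat. When no initial value is given, the sketch materializes an explicit dense zero splat, which only works for a static shape. A dynamic declaration such as `tensor<?x4xf32>` therefore needs an explicit initializer, which is where the initializer's metadata is more specific than the variable's.

```python
from dataclasses import dataclass
from math import prod
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None marks a dynamic ('?') dimension


@dataclass(frozen=True)
class DenseInit:
    """Hypothetical DenseElementsAttr stand-in: static shape plus values.

    A single value is treated as a splat over the whole shape.
    """
    shape: Tuple[int, ...]
    values: Tuple[float, ...]

    def materialize(self) -> Tuple[float, ...]:
        n = prod(self.shape)
        if len(self.values) == 1:
            return self.values * n  # splat
        if len(self.values) != n:
            raise ValueError("value count does not match shape")
        return self.values


def resolve_initial_value(declared: Shape, init: Optional[DenseInit]) -> DenseInit:
    """Return the effective initial value for a variable declaration."""
    if init is None:
        if any(d is None for d in declared):
            raise ValueError("a dynamically shaped variable needs an explicit initial value")
        return DenseInit(tuple(declared), (0.0,))  # explicit dense<0> splat
    rank_ok = len(init.shape) == len(declared)
    if not rank_ok or any(d is not None and d != s for d, s in zip(declared, init.shape)):
        raise ValueError(f"initializer shape {init.shape} incompatible with {declared}")
    return init
```

Making zero initialization an explicit default value, rather than a rule stated in prose, is the design point raised above.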

Sounds good.

Great, we’re in agreement again.

I’m in support of having an (optional) initialization value. Reviewing internal feedback, not having an initial value is based around a simple memory model where the stateful tensors are also passed in as module parameters along with input tensors.

However this isn’t semantically equivalent as the input tensors are rvalue expressions passed in, while the persistent tensors are lvalue expressions that should be assignable. I’m trying to get this topic resolved.
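A toy sketch of the rvalue/lvalue distinction, using a hypothetical accumulator (y = x + state, then state is set to y) that is not taken from any real model. When state is passed in like an input, the function cannot assign it and must return the updated state for the caller to store. When state sits behind a symref, the module itself reads and assigns it. The two produce the same outputs only if the caller faithfully writes the returned state back.

```python
from typing import Dict, Tuple


def step_with_state_param(x: float, state: float) -> Tuple[float, float]:
    """State passed in like an input (an rvalue): it cannot be assigned here,
    so the updated state is returned for the caller to store."""
    y = x + state
    return y, y


class Module:
    """State held behind a symref (an lvalue) that the module reads and assigns."""

    def __init__(self) -> None:
        self.variables: Dict[str, float] = {"acc": 0.0}  # zero-initialized

    def step(self, x: float) -> float:
        state = self.variables["acc"]  # analogous to tosa.variable.read @acc
        y = x + state
        self.variables["acc"] = y      # analogous to tosa.variable.write @acc
        return y
```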