
I’ve only been reading through the docs for a few moments, but I’m pleasantly surprised to find that the authors are using effect handlers to handle effectful computations in ML models. I was in the process of translating a model from torch to JAX using Equinox, and this makes me think Penzai could be a better choice.


I was just reading this too! I think it's a really interesting choice in the design space.

So to elucidate this a little bit, the trade-off is that this is now incompatible with e.g. `jax.grad` or `lax.scan`: you can't compose things in the order `discharge_effect(jax.grad(your_model_here))`, or put an effectful `lax.scan` inside your forward pass, etc. The effect-discharging process only knows how to handle traversing pytree structures. (And they do mention this at the end of their docs.)

This kind of thing was actually something I explicitly considered later on in Equinox, but in part decided against as I couldn't see a way to make that work either. The goal of Equinox was always absolute compatibility with arbitrary JAX code.

Now, none of that should be taken as a bash at Penzai! They've made a different set of trade-offs, and if the above incompatibility doesn't affect your goals then indeed their effect system is incredibly elegant, so certainly give it a try. (Seriously, it's been pretty cool to see the release of Penzai, which explicitly acknowledges how much it's inspired by Equinox.)


Author of Penzai here! In idiomatic Penzai usage, you should always discharge all effects before running your model. While it's true you can't do `discharge_effect(jax.grad(your_model_here))`, you can still do `jax.grad(discharge_effect(your_model_here))`, which is probably what you meant to do anyway in most cases. Once you've wrapped your model in a handler layer, it has a pure interface again, which makes it compatible with arbitrary JAX transformations. The intended use of effects is as an internal helper to simplify plumbing of values into and out of layers, not as something that affects the top-level interface of using the model!
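To make the ordering concrete, here is a toy sketch in plain JAX. It does not use Penzai's actual API: `effectful_model`, `ask`, and `handle_side_input` are hypothetical stand-ins, assuming an "effect" is just a request for a named side input and a "handler" is a wrapper that resolves those requests and returns a pure function.

```python
import jax
import jax.numpy as jnp


def effectful_model(params, x, ask):
    # `ask` is a hypothetical stand-in for an unresolved side-input effect:
    # the model requests a value by name instead of taking it as an argument.
    mask = ask("mask")
    return jnp.sum(params * x * mask)


def handle_side_input(model, side_inputs):
    # Hypothetical handler: resolves every request from `side_inputs`
    # and returns a function with a pure (params, x) -> scalar interface.
    def pure_model(params, x):
        return model(params, x, lambda name: side_inputs[name])

    return pure_model


params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([1.0, 1.0, 2.0])
mask = jnp.array([1.0, 0.0, 1.0])

# Handle the effect first, then transform: grad(handled(model)).
pure_model = handle_side_input(effectful_model, {"mask": mask})
grads = jax.grad(pure_model)(params, x)  # gradient w.r.t. params is x * mask
```

Going the other way round would mean asking `jax.grad` to differentiate a function whose side input is still unresolved, which is the composition that doesn't work.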

(As an example of this, the GemmaTransformer example model uses the SideInput effect internally to do attention masking. But it exposes a pure functional interface by using a handler internally, so you can call it anywhere you could call an Equinox model, and you shouldn't have to think about the effect system at all as a user of the model.)

It's not clear to me what the semantics of ordinary JAX transformations like `lax.scan` should be if the model has side effects. But if you don't have any effects in your model, or if you've explicitly handled them already, then it's perfectly fine to use `lax.scan`. This is similar to how it works in ordinary JAX; if you try to do a `lax.scan` over a function that mutates Python state, you'll probably hit an error or get something unexpected. But if you only mutate Python state that is local to the function you pass to `lax.scan`, it works fine.
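The plain-JAX analogy can be sketched roughly as follows. The function names are made up for illustration; the only library calls assumed are `jax.lax.scan` and `jax.numpy`.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(4.0)
init = jnp.zeros(())


def step_local(carry, x):
    # Python state created and mutated inside the body never escapes it,
    # so the function still looks pure to lax.scan.
    parts = []
    parts.append(carry + x)
    parts.append(x * 2.0)
    return parts[0], parts[1]


final, ys = lax.scan(step_local, init, xs)  # final = 6.0, ys = [0., 2., 4., 6.]

log = []


def step_leaky(carry, x):
    # Mutating outer Python state happens when the body is traced,
    # not once per scan step, and what gets stored is a tracer.
    log.append(x)
    return carry + x, x


leaky_final, _ = lax.scan(step_leaky, init, xs)
# `log` does not hold one concrete value per step; typically it holds
# a single traced placeholder, which is the "something unexpected" case.
```

The numeric result of the leaky version is still correct here; it's the side channel (`log`) that silently misbehaves.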

I'll also note that adding support for higher-order layer combinators (like "layer scan") is something that's on the roadmap! The goal would be to support some of the fancier features of libraries like Flax when you need them, while still admitting a simple purely-functional mental model when you don't.


Thanks! This is one of the more experimental design choices I made in Penzai, but so far I've found it to be quite useful.

The effect system does come with a few sharp edges at the moment if you want to use JAX transformations inside the forward pass of your model (see my reply to Patrick), but I'm hoping to make it more flexible as time goes on. (Figuring out how effect systems should compose with function transformations is a bit nontrivial!)

Please let me know if you run into any issues using Penzai for your model! (Also, most of Penzai's visualization and patching utilities should work with Equinox too, so you shouldn't necessarily need to fully commit to either one.)


This is something I’ve thought about in the past, since I messed around with trying to add monads to JAX - I think you made the right call with effect handlers. You might want to take a look at what Koka does; that was the best implementation of effect handlers the last time I checked.



