That's cool -- but wouldn't it be more constructive to discuss "the ideas" in this package anyways?
For instance, it would be interesting to discern whether the design of PyTorch (and its modules) precludes or admits the same sort of visualization tooling. If you have expertise in PyTorch, perhaps you could help answer this sort of question?
JAX's Pytrees are like "immutable structs, with array leaves" -- does PyTorch have a similar concept?
I don't know that you actually need that immutability. You could probably reconstruct enough to do this kind of viz from the autograd graph, or capture the graph and intermediates during the forward pass using hooks. My hunch is that it should be doable.
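Here is a minimal sketch of both capture routes, assuming a toy `nn.Sequential` model. `register_forward_hook`, `named_modules`, and `grad_fn.next_functions` are real PyTorch APIs. The names `activations`, `make_hook`, and `autograd_nodes` are hypothetical helpers made up for illustration. The sketch only collects the raw data such a tool would need and does not draw anything.

```python
import torch
import torch.nn as nn

# Toy stand-in model (an assumption); any nn.Module tree is traversed the same way.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Hypothetical store: qualified submodule name -> output from the last forward pass.
activations = {}

def make_hook(name):
    # PyTorch calls forward hooks as hook(module, inputs, output) after forward() runs.
    def hook(module, inputs, output):
        activations[name] = output.detach()  # keep the value, drop the autograd history
    return hook

handles = []
for name, module in model.named_modules():
    if name == "":  # named_modules() yields the root container first, with an empty name
        continue
    handles.append(module.register_forward_hook(make_hook(name)))

x = torch.randn(3, 4)
y = model(x)

# Hooks stay attached until removed, so detach them once the intermediates are captured.
for h in handles:
    h.remove()

def autograd_nodes(root):
    """Collect the type names of backward-graph nodes reachable from root (each visited once)."""
    names, stack, visited = [], [root], set()
    while stack:
        fn = stack.pop()
        if fn is None or id(fn) in visited:
            continue
        visited.add(id(fn))
        names.append(type(fn).__name__)
        # next_functions holds (node, input_index) pairs; node may be None.
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return names

graph_nodes = autograd_nodes(y.grad_fn)

for name, act in activations.items():
    print(name, tuple(act.shape))
print(graph_nodes)
```

The hook route yields per-module values keyed by the module hierarchy, which is closer to what a pytree-style view shows. The autograd route yields operation-level structure but loses the module names, so a real tool would probably want to combine the two.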