
Backend framework comparison

Albert Zeyer edited this page Nov 8, 2021 · 9 revisions

We want to compare TensorFlow, PyTorch, JAX and maybe other similar frameworks here.

We do not want to compare higher-level frameworks like Keras here.

Side note: The classification into low-level, mid-level and high-level frameworks comes from our tutorial (video, slides). What we call "backend framework" here probably corresponds most closely to "mid-level framework", but the distinction is not always clear. Here we mean all frameworks that could potentially be used as a backend for RETURNN.

This comparison is specifically about their properties when used as a backend framework for RETURNN, following the RETURNN principles and core features.

Theano

  • Our initial backend.
  • The most widely used backend framework around 2016-2017, maybe?
  • Performs optimizations on the computation graph. This could improve the training runtime, but optimizing and compiling the computation graph often took a lot of time (on the order of minutes).
  • It supports non-contiguous tensors (by storing strides). Any op that needs a contiguous tensor explicitly makes its input contiguous first. The user usually never needs to care about this.
  • It supports inplace operations. The automatic optimization finds tensors that are no longer used after some op and, when possible, replaces that op by an inplace op. This can reduce memory usage and sometimes also runtime.
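The inplace rewrite described above can be illustrated with a minimal sketch of such a pass over a toy graph. This is not Theano's actual optimizer: the `Op` class, the `mark_inplace` function and the rule "only the first input may be overwritten" are hypothetical, chosen just to show the liveness check (a tensor may be overwritten only if nothing reads it afterwards).

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass
class Op:
    """A node in a toy computation graph (hypothetical, not Theano's API)."""
    name: str
    inputs: List[str]
    output: str
    can_be_inplace: bool  # e.g. an elementwise op that may overwrite inputs[0]
    inplace: bool = False


def mark_inplace(ops: List[Op], protected: Set[str]) -> List[Op]:
    """Mark ops that may overwrite their first input instead of allocating.

    `ops` must be in execution (topological) order. `protected` holds tensors
    that must never be overwritten, e.g. graph inputs and graph outputs.
    An op becomes in-place only if its first input is not read by any later op.
    """
    for idx, op in enumerate(ops):
        if not op.can_be_inplace or not op.inputs:
            continue
        target = op.inputs[0]
        if target in protected:
            continue
        # Liveness check: is the candidate tensor still needed afterwards?
        used_later = any(target in later.inputs for later in ops[idx + 1:])
        if not used_later:
            op.inplace = True
    return ops


if __name__ == "__main__":
    graph = [
        Op("mul", ["x", "w"], "h", can_be_inplace=True),  # x is a graph input
        Op("tanh", ["h"], "a", can_be_inplace=True),      # h is read again below
        Op("add", ["a", "h"], "y", can_be_inplace=True),  # a is dead afterwards
    ]
    for op in mark_inplace(graph, protected={"x", "w", "y"}):
        print(op.name, "inplace" if op.inplace else "new buffer")
```

Only the final add is rewritten to work inplace here. A real optimizer must additionally account for aliasing between views and for ops that cannot safely overwrite their input; this sketch ignores both.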

TensorFlow

  • Our second backend.
  • The most widely used backend framework around 2017-2019, maybe?
  • Performs only minimal optimizations on the computation graph, such that this step is almost instant.
  • Does not support non-contiguous tensors (see here, here, here). This has a couple of implications: for example, operations like tf.transpose will perform a copy of the tensor.
  • It does not support inplace operations.
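Why a transpose implies a copy when only contiguous layouts exist can be sketched in plain Python. This is a conceptual illustration, not TensorFlow's kernel; `transpose_copy` is a hypothetical helper operating on a flat row-major buffer.

```python
from typing import List, Sequence


def transpose_copy(data: Sequence[float], rows: int, cols: int) -> List[float]:
    """Transpose a row-major (rows x cols) matrix into a freshly allocated buffer.

    With only contiguous (row-major) layouts available, the transposed result
    cannot reuse the input memory, so every element is copied: O(rows * cols)
    time and an extra buffer of the same size.
    """
    if len(data) != rows * cols:
        raise ValueError("data length does not match rows * cols")
    out = [0.0] * (rows * cols)
    for i in range(rows):
        for j in range(cols):
            # Element (i, j) of the input becomes element (j, i) of the output,
            # whose row-major index in a (cols x rows) matrix is j * rows + i.
            out[j * rows + i] = data[i * cols + j]
    return out
```

For example, transpose_copy([1, 2, 3, 4, 5, 6], 2, 3) returns the new list [1, 4, 2, 5, 3, 6], leaving the input untouched.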

PyTorch

  • Probably the most widely used framework since around 2019?
  • Eager mode is first-class, which is maybe one reason people like it.
  • Has a very clean API, including a higher-level API (all in torch.nn: the base class nn.Module, and also modules like LSTM and Transformer). This is somewhat similar to TF Keras but arguably cleaner, which is another reason people like it.
  • It supports non-contiguous tensors: transpose always just returns a view, i.e. it is a very cheap op.
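The difference between a strided view and a copy can be sketched with a minimal 2D view in plain Python. It mirrors the idea behind PyTorch's transpose, Tensor.is_contiguous() and Tensor.contiguous() (and Theano's stride handling), but `StridedView` is a hypothetical class written for illustration, not PyTorch's implementation.

```python
from typing import List, Tuple


class StridedView:
    """Minimal strided 2D view over a flat storage list (hypothetical).

    Element (i, j) lives at storage[offset + i * strides[0] + j * strides[1]].
    """

    def __init__(self, storage: List[float], shape: Tuple[int, int],
                 strides: Tuple[int, int], offset: int = 0):
        self.storage = storage
        self.shape = shape
        self.strides = strides
        self.offset = offset

    @classmethod
    def from_rows(cls, rows: List[List[float]]) -> "StridedView":
        """Build a contiguous row-major view from nested lists."""
        n, m = len(rows), len(rows[0])
        flat = [x for row in rows for x in row]
        return cls(flat, (n, m), (m, 1))

    def __getitem__(self, idx: Tuple[int, int]) -> float:
        i, j = idx
        return self.storage[self.offset + i * self.strides[0] + j * self.strides[1]]

    def transpose(self) -> "StridedView":
        """O(1): swap shape and strides, sharing the same storage (no copy)."""
        return StridedView(self.storage, self.shape[::-1], self.strides[::-1], self.offset)

    def is_contiguous(self) -> bool:
        """True if the strides describe a dense row-major layout.

        Conservative: ignores the special case of size-1 dimensions.
        """
        return self.strides == (self.shape[1], 1)

    def contiguous(self) -> "StridedView":
        """Return self if already row-major, else copy into new row-major storage."""
        if self.is_contiguous():
            return self
        n, m = self.shape
        flat = [self[i, j] for i in range(n) for j in range(m)]
        return StridedView(flat, (n, m), (m, 1))
```

Here transpose only swaps two small tuples, while the element-by-element copy is deferred to contiguous(), which an op needing a dense layout would call first. With x = StridedView.from_rows([[1, 2, 3], [4, 5, 6]]), x.transpose() shares x.storage and has strides (1, 3).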

JAX

  • Recently gained lots of interest in the community.