When exporting a model whose inputs have symbolic (non-fixed) shapes, the exporter bakes in the smallest possible size (i.e., 1) for those dimensions.
Colab to reproduce:
https://colab.research.google.com/drive/1H81pAhOWr42YFnzrAT4zGI3XooXvM8oe
Suggestions:
- Expose the existing options: `model.export()` already accepts `**kwargs`, and `torch.onnx.export()` takes `dynamic_shapes` / `dynamic_axes`, so forwarding those to torch would work as is.
- (Better UX) Compute `dynamic_shapes` / `dynamic_axes` automatically from the input spec.
- Or all of the above.
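A minimal sketch of the "compute it from the input spec" idea, assuming the spec can be reduced to a mapping from input name to a shape tuple where `None` marks a symbolic dimension. The helper name `dynamic_axes_from_spec` and the `<name>_dim<i>` axis-naming scheme are hypothetical; only the output format (`{input_name: {axis_index: axis_name}}`) follows the dict form accepted by the `dynamic_axes` argument of `torch.onnx.export()`.

```python
from typing import Dict, Optional, Sequence


def dynamic_axes_from_spec(
    input_shapes: Dict[str, Sequence[Optional[int]]],
) -> Dict[str, Dict[int, str]]:
    """Hypothetical helper: collect the axes whose size is None (symbolic).

    Returns {input_name: {axis_index: axis_name}}, the dict form used by
    the ``dynamic_axes`` argument of ``torch.onnx.export``. Inputs with
    fully static shapes are omitted.
    """
    dynamic_axes: Dict[str, Dict[int, str]] = {}
    for name, shape in input_shapes.items():
        # Give each symbolic axis its own name so no equality is implied
        # between dimensions of different inputs.
        axes = {i: f"{name}_dim{i}" for i, dim in enumerate(shape) if dim is None}
        if axes:
            dynamic_axes[name] = axes
    return dynamic_axes


# Hypothetical input spec: batch size (and sequence length) are symbolic.
spec = {"tokens": (None, None), "image": (None, 224, 224, 3)}
print(dynamic_axes_from_spec(spec))
# {'tokens': {0: 'tokens_dim0', 1: 'tokens_dim1'}, 'image': {0: 'image_dim0'}}
```

The result would be passed together with matching `input_names` (e.g. `input_names=list(spec)`), since `dynamic_axes` keys refer to those names; how Keras would wire this into `model.export()` is an assumption here. Using one shared name such as `"batch"` for axis 0 of every input would instead tell the exporter those sizes are equal.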