
Usage of Sparse Tensors? #1

Open
karanchahal opened this issue Jul 11, 2019 · 15 comments
@karanchahal commented Jul 11, 2019

Hello,

I see you've been using boolean masks to mask out the weights of the PyTorch network. Is there a way to use sparse tensors to achieve an actual speedup in inference in PyTorch currently?

Thank you for your nice library :) I find unstructured pruning so fascinating; it's a monumental speedup waiting to happen.

@TimDettmers (Owner)

I was quite interested in using sparse tensors myself, and I also worked on sparse GPU algorithms, but unfortunately, it is a very complicated issue, especially if you want to keep all speedups while preserving the structure of PyTorch. Currently, you will not be able to gain speedups by using sparse tensors.

One problem is sparse gradients: if you have sparse-dense matrix multiplications or convolutions in the backward pass, you will not gain speedups. However, adding sparsity in the backward pass in a clean way would entail complicated modifications to the PyTorch backend. Either you add sparse forward and backward operations for each layer (that is a lot of code), or you add distinct sparse layers and somehow marry dense and sparse backward-forward passes under a general framework. Both approaches are a mess!

Another issue is the speedups themselves. The current cuSPARSE sparse matrix multiplication routines are bad for deep neural networks (they are better suited to high-performance computing workloads). I worked on a more efficient matrix multiplication algorithm in 16-bit and developed an algorithm which is faster than cuBLAS, but this algorithm is only suitable for sparse matrix multiplication. Unlike the dense case, I do not think we can use col2im-style operations to perform efficient sparse convolution. My sparse matrix multiplication might be interesting for transformers, but an efficient sparse convolution is still further away.

One potential solution is to compromise on fine-grained structure to enable simpler, more efficient sparse algorithms. Scott Gray developed block-sparse kernels, but these do not work for algorithms like sparse momentum. However, a different approach might be to modify algorithms like sparse momentum to use a little bit more structure. I experimented with that: I had an algorithm which made growth decisions for an entire filter in a convolutional network (instead of for individual weights within a filter), but that algorithm did not perform well. Mostafa & Wang (2019) made similar observations. So it is unclear if an approach with more structure can be competitive with more fine-grained sparsity algorithms like sparse momentum.

With Graphcore hardware on the horizon, things become even more unclear: is the effort put into GPU algorithms worth it? I think, for now, it might be practical to use inefficient dense tensor processing to "simulate" sparse networks for research. Once we know more about successful procedures for sparse learning, we will have a better idea of what kind of GPU algorithms we need and also how to best design deep learning libraries to support sparse computation.
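In code, the "simulate sparsity with masks" approach boils down to something like the minimal sketch below (a toy linear layer with a random fixed mask; this is not the exact implementation in this library, just an illustration):

```python
import torch
import torch.nn as nn

# Minimal sketch: simulate an unstructured-sparse layer by multiplying a dense
# weight with a fixed boolean mask. The computation stays dense, so there is
# no speedup; the mask only zeroes out weights.
class MaskedLinear(nn.Module):
    def __init__(self, in_features, out_features, density=0.1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Random mask that keeps roughly `density` of the weights (example rule).
        mask = torch.rand_like(self.linear.weight) < density
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Dense matmul with masked weights: simulated sparsity at dense cost.
        return nn.functional.linear(x, self.linear.weight * self.mask,
                                    self.linear.bias)

layer = MaskedLinear(2000, 2000, density=0.1)
out = layer(torch.randn(32, 2000))
```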

@karanchahal (Author)

Thank you for your very well thought out answer. I had no idea about the various pain points with regard to getting speedups in sparse networks. I was mainly thinking of accelerating inference for now and not worrying about training at present.

Is it possible to get a speedup in dense-sparse matrix multiplication? I was trying to reason it out using a naive example but am unable to figure out how it can be done; it seems like you'd have to load both matrices into memory.

Do you have a link to a paper/code that details your sparse matrix multiplication?

Also, as a side note, are you participating in the NeurIPS 2019 MicroNet challenge?

@karanchahal (Author)

I just read your blog post on sparse learning. I'm curious: did you test against the "To Prune or not to Prune" paper by the Google people (they use it in the TF model optimization toolkit)? How did it do against that algorithm? I would be very curious to see this algorithm applied to object detection (I'm currently working on this, although training a Mask R-CNN from scratch is outside my computational budget!).

I was, however, able to achieve a nice 80% sparsity (with very minimal accuracy degradation; I suspect with more training it'll come up to dense accuracy) in the convolutional filters of a pre-trained Mask R-CNN using the above-mentioned algorithm.

Your paper was very interesting and something a lot more people should be talking about!

Cheers :)

@TimDettmers (Owner)

Thank you for your kind words! I had not read the paper before, but I just skimmed it, and the main difference between "To Prune or not to Prune" and our work is that we maintain sparsity from beginning to end while they go from dense to sparse. Going from dense to sparse has an advantage for predictive performance but does not enable faster training for the majority of the training run.

In my work, I also experimented with a mix of these two methods: start with, say, 50% of the weights and prune more weights than you regrow so that you finish with, say, 25% of the weights. The predictive performance is much better than starting from 25% of the weights, and one still gains some speedups. However, the entire procedure is a bit more complicated, and that is why I did not include it in the paper.
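As a rough sketch, the schedule looks something like this (hypothetical numbers, not the exact code from the paper):

```python
# Hypothetical sketch of a "prune more than you regrow" schedule: start at
# 50% density and decay linearly to 25% over training, so each prune/regrow
# step removes slightly more weights than it adds back.
def target_density(step, total_steps, start=0.50, end=0.25):
    frac = min(step / total_steps, 1.0)
    return start + (end - start) * frac

total_weights = 1_000_000
prune_frac = 0.03                      # prune 3% of all weights per step (example)
for step in range(1, 101):
    k_prune = int(prune_frac * total_weights)
    # Regrow fewer weights than were pruned, by exactly the amount needed
    # to hit the density schedule for this step.
    shrink = target_density(step - 1, 100) - target_density(step, 100)
    k_grow = k_prune - int(shrink * total_weights)
    if step % 25 == 0:
        print(step, k_prune, k_grow, round(target_density(step, 100), 3))
```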

@TimDettmers self-assigned this Jul 15, 2019
@TimDettmers added the question (Further information is requested) label Jul 15, 2019
@karanchahal (Author)

Ah, interesting. Is there a way the structured pruning approach could rival the accuracy of unstructured pruning? Have you tried pruning along an index of a dimension (say, a conv channel filter) slowly? Start with, say, 25% of the weights in a certain index of a dimension, and then slowly go up to 100% so that we get some block sparsity. The index of the dimension could be selected using some threshold.
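Roughly what I mean, as a toy sketch (the scoring rule and numbers are made up, just to make the idea concrete):

```python
import torch

# Toy sketch of the idea above (hypothetical): pick a subset of output channels
# of a conv weight by some importance score, then only grow weights inside those
# channels, so the nonzeros stay grouped per channel (block sparsity).
weight = torch.randn(64, 32, 3, 3)            # (out_ch, in_ch, kH, kW)
mask = torch.zeros_like(weight, dtype=torch.bool)

score = weight.abs().mean(dim=(1, 2, 3))      # per-output-channel importance
selected = score > score.median()             # threshold picks roughly half the channels

# Start by keeping 25% of the weights inside the selected channels only...
keep = torch.rand_like(weight[selected]) < 0.25
mask[selected] = keep
# ...and over training, gradually raise that 25% toward 100% within the same
# channels, so the sparsity pattern stays blocked along the channel dimension.
```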

@TimDettmers (Owner)

This is a very interesting idea. I think from a computational perspective this would be a very promising direction — if successful, it would definitely help us to develop and use much less complex sparse algorithms which perform better.

I currently do not have time to look at this since I am starting on a different project (I plan to return to sparse learning by the end of the year), but I am refactoring the library at the moment, and it will then be much easier to implement new algorithms like this, so you might want to give it a try yourself at that point. I will let you know once I have implemented these changes.

@mitchellgordon95 commented Aug 8, 2019

Popping in to mention that sparse tensors would be helpful not just for training speed-ups, but also for GPU memory consumption.

For example, my group has been struggling with training BERT. The large version won't allow training with a batch size of even a single example on a normal GPU... you need a TPU pod to get started. That rules out methods like gradient accumulation, although gradient checkpointing can help.

If we could train sparse, though, I think that would allow us to use commodity hardware. Am I wrong?

@TimDettmers (Owner)

One general problem regarding memory savings is that even though your weights are sparse, you still get dense outputs if you have relatively dense inputs. This will probably change in the future as people work on sparse feature algorithms, but currently you would only save the memory of the weights. If you look at BERT, however, you find that 2/3 of the memory comes just from the activations and from storing the gradients. Since these are dense anyway, my method would not help here at the moment.

@varun19299

This library (Sparse Linear) uses sparse tensors for MLPs. It seems to be more efficient for storage (expected), and also for inference / training (not entirely sure how, given it's unstructured sparsity).
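Conceptually, what such a sparse MLP layer does is roughly this (a rough sketch using torch.sparse, not the actual Sparse Linear code):

```python
import torch

# Rough sketch of an unstructured sparse linear layer: the weight is stored in
# COO format and the forward pass is a sparse-matrix x dense-matrix product.
in_features, out_features, density = 2000, 2000, 0.1

dense_w = torch.randn(out_features, in_features)
dense_w[torch.rand_like(dense_w) >= density] = 0.0
weight = dense_w.to_sparse()                  # COO sparse weight

x = torch.randn(32, in_features)              # dense input batch
# torch.sparse.mm expects (sparse, dense): (out, in) x (in, batch) -> (out, batch)
y = torch.sparse.mm(weight, x.t()).t()
```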

Is this feasible for convolutions?

Perhaps for convolutions, you could skip certain channel connections entirely?

@TimDettmers (Owner)

If you use the GEMM formulation for convolution, you could use very similar code: just add im2col and col2im functions. In this way, you could reuse the sparse linear library.

Unstructured sparsity can yield faster training through a regular sparse matrix multiplication as implemented in cuSPARSE. But the speedups are not linear: if you keep 10% of the weights, you can probably expect a speedup of 2-4x or so.
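As a rough sketch of what I mean (hypothetical code using torch.nn.functional.unfold for im2col and a COO sparse weight; not an optimized path):

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch: convolution as im2col (unfold) + sparse matmul.
# This reuses a sparse linear kernel for the conv; it is not an optimized path.
x = torch.randn(8, 16, 32, 32)                 # (N, C_in, H, W)
w = torch.randn(32, 16, 3, 3)                  # (C_out, C_in, kH, kW)
w[torch.rand_like(w) >= 0.1] = 0.0             # ~90% unstructured sparsity
w_sparse = w.view(32, -1).to_sparse()          # (C_out, C_in*kH*kW), COO

cols = F.unfold(x, kernel_size=3, padding=1)   # (N, C_in*kH*kW, H*W)
N, K, L = cols.shape
# Sparse-dense GEMM: (C_out, K) x (K, N*L) -> (C_out, N*L)
out = torch.sparse.mm(w_sparse, cols.permute(1, 0, 2).reshape(K, N * L))
# Back to (N, C_out, H, W); spatial size is unchanged because of padding=1.
out = out.reshape(32, N, L).permute(1, 0, 2).reshape(N, 32, 32, 32)
```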

@varun19299

Thank you for your reply!

Yes, I was thinking of using GEMM, maybe something like this.

Not completely pertinent to this repo, but I played around with Sparse Linear, and for lower sparsities (< 90%) it's actually worse than dense inference or training. The timings seem quite far off from what cuSPARSE could have given.

What are your thoughts on this? Yes, I should probably recheck with a few newer GPUs (I tried this on Colab with a P100). Is it possible that torch.sparse and pytorch_sparse are quite suboptimal at the moment?

@TimDettmers (Owner)

cuSPARSE performance can be quite weak for certain matrix sizes and sparsity values and can easily be slower than dense. Do you have some numbers for this, that is, speed/time taken and the size of the operations? I suspect the difference will grow even further with newer GPUs, since sparse patterns do not easily benefit from tensor cores.

@varun19299 commented Oct 6, 2020

Here, I tried a linear layer with an input dimension of 2000, an output dimension of 2000, and 90% sparsity.

Speedups are seen for >99% sparsity.

The input features are dense, so it is a sparse-matrix dense-matrix multiplication.

@TimDettmers (Owner)

The example that you posted has a couple of problems. For benchmarking you need to use CUDA streams for precise timing, since kernel executions are asynchronous. You also need to burn in the GPU with a couple of iterations before benchmarking, or the first functions that you execute will be slower. The dimensions are also a bit off: your batch size should be at least 32 to make efficient use of GPU resources. I would be curious to see new benchmarks with those problems fixed.
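Something along these lines, for example (a rough sketch with a warm-up phase and CUDA events for timing; all sizes are just placeholders):

```python
import torch

# Rough benchmarking sketch: warm up the GPU first, then time with CUDA events,
# since kernel launches are asynchronous on the host side.
device = "cuda"
batch, dim, density = 128, 2048, 0.1

w = torch.randn(dim, dim, device=device)
w[torch.rand_like(w) >= density] = 0.0
w_sparse = w.to_sparse()
x = torch.randn(dim, batch, device=device)

# Burn-in: a few iterations so timings are not dominated by one-time setup costs.
for _ in range(10):
    torch.sparse.mm(w_sparse, x)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    torch.sparse.mm(w_sparse, x)
end.record()
torch.cuda.synchronize()
print(f"sparse mm: {start.elapsed_time(end) / 100:.3f} ms/iter")
```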

@varun19299 commented Oct 22, 2020

Yes, I didn't benchmark it accurately (with CUDA streams, GPU warm up, etc.).

I'll do that soon with dimensions in powers of 2, although I don't suspect it will change the end observation: the timings were a bit too far off. I'll update you with these timings.
