
[sparse] Update support for arbitrary N:M sparse settings #1631

Open
wants to merge 1 commit into master

Conversation

LeiWang1999

Motivation

Recently, many studies have focused on N:M sparsity, and several N:M kernel libraries have appeared as well. These are not limited to 2:4 sparsity: for example, the N:M sparsity GPU kernels from MLSys 2023 (nmSPARSE) use M = 32 to avoid bank conflicts and achieve state-of-the-art performance on sparse GEMV, and on matmul at some sparsity ratios. However, such kernels require training with arbitrary N:M settings, while the current mask_calculator only supports element-wise sparsity patterns.

To support generating general N:M sparse weights, this pull request makes three major extensions to ASP:

  1. The original ASP only implements pruning with 2:4 sparsity and fails for M larger than about 10, because it enumerates mask candidates with itertools.permutations, whose cost is factorial in M; mask generation above M = 10 becomes unacceptably slow. We extend it to support arbitrary N:M settings with a greedy top-k selection whose cost is roughly linear in the number of weights. We normally set M = 32, since a window of 32 is large enough to cover the sparsity ratios typically used in DNNs. (A quick usage check appears after this list.)

    """ m:n 1d structured pruning: greedy method to select mask """
    def mn_1d_greedy(matrix, m, n):
        mat, shape = reshape_1d(matrix,m)
        mask = torch.cuda.IntTensor(matrix.shape).fill_(0).view(-1,m)
    
        values, indices = torch.abs(mat).topk(n, dim=1, largest=True)
        indexes = torch.arange(0, indices.shape[0], step=1, dtype=torch.long).view(-1, 1)
    
        mask[indexes, indices] = 1
        mask = mask.view(matrix.shape)
    
        return mask.cuda()
  2. We extend ASP to support vector-wise (VW) and block-wise (BW) N:M sparsity for nmSPARSE, to achieve higher speedups.

    def unstructured_vector_wise(matrix, density, v):
        # group the weights into vectors of length v
        mat = matrix.view(-1, v)
        num_vectors = mat.shape[0]
        n = int(num_vectors * density)

        # rank vectors by their L1 norm (sum of absolute values),
        # consistent with the block-wise variant below
        mask = torch.cuda.IntTensor(mat.shape).fill_(0)
        mat_reduce = torch.sum(torch.abs(mat), dim=-1)
        values, indices = mat_reduce.topk(n, dim=0, largest=True)

        # keep the n most important vectors dense
        mask[indices, :] = 1
        mask = mask.view(matrix.shape)
        return mask
    
    def unstructured_block_wise(matrix, density, bh, bw):
        # split into bh x bw tensor blocks: mat has shape (bm, bn, bh, bw)
        mat, shape = tensor_block_partition(matrix, bh, bw)
        (bm, bn, bh, bw) = mat.shape
        mask = torch.cuda.IntTensor(mat.shape).fill_(0)

        # rank blocks by the sum of absolute weight values in each block
        mat_reduce = torch.sum(torch.abs(mat), dim=(-1, -2)).view(-1)
        n = int(bm * bn * density)
        values, indices = torch.topk(mat_reduce, n, dim=-1, largest=True)

        # keep the n most important blocks dense (vectorized; replaces the
        # original per-index Python loop, which was slow)
        mask.view(bm * bn, bh, bw)[indices] = 1

        # undo the block partition (assumes tensor_block_partition tiles the
        # 2D matrix row-major into (bm, bn, bh, bw))
        mask = mask.permute(0, 2, 1, 3).reshape(matrix.shape)

        return mask.cuda()
  3. We further enable simple layer-wise sparsity ratio configuration for ASP, because different DNN layers favor different sparsity ratios. With this config, the first layer_wise_ratio percent of the layers keep their weights dense, while the remaining layers use the given sparsity pattern.

    def prune_trained_model(cls, model, optimizer, mask_calculator="m4n2_1d", sparsity=0.5, layer_wise_ratio=0.0):
        # add mask buffers to model (init_model_for_pruning), augment optimizer
        # (init_optimizer_for_pruning), and compute masks (compute_sparse_masks)
        cls.init_model_for_pruning(model, mask_calculator, (1 - sparsity), verbosity=2,
                                   whitelist=[torch.nn.Linear, torch.nn.Conv2d],
                                   allow_recompute_mask=False)
        cls.init_optimizer_for_pruning(optimizer)
        cls.compute_sparse_masks(layer_wise_ratio)
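
As a quick sanity check of extension 1, a hypothetical usage sketch (not part of the PR; assumes the mn_1d_greedy function above is importable and a CUDA device is available):

    import torch

    matrix = torch.randn(64, 128, device="cuda")
    mask = mn_1d_greedy(matrix, m=32, n=8)  # keep 8 of every 32 weights (75% sparse)

    # every group of m=32 consecutive weights keeps exactly n=8 entries
    per_group = mask.view(-1, 32).sum(dim=1)
    assert torch.all(per_group == 8)
    print("density:", mask.float().mean().item())  # ~0.25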


How to use?

The changes in this PR are backward compatible with existing Apex-based code, since the default value of layer_wise_ratio is 0.

ASP.prune_trained_model(model, optimizer, mask_calculator=args.mask_calculator, sparsity=args.sparsity, layer_wise_ratio=args.layer_wise_ratio)

For example, if layer_wise_ratio is set to 0.25, the first 6 layers of BERT-large keep their weights dense, while the last 18 layers use the given sparse pattern.

End-to-End Accuracy

BERT-large on SQuAD v1.1, with layer_wise_ratio fixed at 0.166 (4 of 24 layers remain dense):

F1 scores under different N:M settings. Under a fixed sparsity ratio, different N:M settings have no obvious impact on model accuracy.

[image: table of F1 scores under different N:M settings]

Comparing F1 scores with and without the N:M constraint for EW, VW, and BW sparsity. Adding the N:M constraint has no obvious or consistent impact on model accuracy.

[image: F1 scores with vs. without the N:M constraint for EW, VW, and BW sparsity]

@crcrpar (Collaborator) commented Apr 18, 2023:

cc @jpool-nv @ChongyuNVIDIA could you review this?

cur_pattern = np.zeros(m, dtype=np.int32)
cur_pattern[list(i)] = 1
valid_patterns.append(cur_pattern)
valid_patterns = torch.Tensor(np.array(valid_patterns))
@jpool-nv (Contributor):

Hi @LeiWang1999 - as I suggested in your previous PR 1501, I'd prefer if the existing 2:4 mask candidate orders were maintained. We can accomplish this with:
valid_patterns = torch.Tensor(list(set([tuple(vp) for vp in valid_patterns])))
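One caveat with the one-liner above: set() does not guarantee insertion order, so if the original 2:4 candidate order must be preserved exactly, an order-preserving dedup (a sketch using dict.fromkeys, not from the thread) would be:

    # dict.fromkeys keeps first-seen order, unlike set()
    valid_patterns = torch.Tensor(list(dict.fromkeys(tuple(vp) for vp in valid_patterns)))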

return mask

""" m:n 1d structured pruning: greedy method to select mask """
def mn_1d_greedy(matrix, m, n):
@jpool-nv (Contributor):

Can you comment on why the mn_1d_greedy method is needed, and/or how its masks differ from the mn_1d_best mask for a given matrix?

return mask.cuda()


def m32n3_1d_best(mat, density):
@jpool-nv (Contributor):

With a focus on 2:4 sparsity, the burden of remapping from calculator string to function via func = getattr(sys.modules[__name__], pattern, None) wasn't so painful. With so many N:M options in use, it may make sense to find a better way to handle this; what do you think?

Maybe parsing the calculator string for m, n, dimensionality, and any extra tag ("_best", "_greedy", etc.) then calling mn_xd_tag(mat, m, n) directly could work and handle any future extensions.
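A minimal sketch of that parsing idea (hypothetical; the regex, generic function names, and fallback behavior are assumptions, not existing ASP code):

    import re
    import sys

    def create_mask_from_pattern(mat, pattern):
        """Parse a calculator string like 'm32n3_1d_best' or 'm4n2_1d'
        and dispatch to a generic mn_<dim>_<tag>(mat, m, n) implementation."""
        match = re.fullmatch(r"m(\d+)n(\d+)_(\d+d)(?:_(\w+))?", pattern)
        if match is None:
            # fall back to the existing per-name lookup for legacy calculators
            func = getattr(sys.modules[__name__], pattern, None)
            return func(mat)
        m, n, dim, tag = match.groups()
        func = getattr(sys.modules[__name__], f"mn_{dim}_{tag or 'best'}")
        return func(mat, int(m), int(n))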

@jpool-nv (Contributor) commented:

Hi @LeiWang1999 - thanks for this PR! I think there are some good things within.

For each extension in turn:

  1. This makes perfect sense. Thanks again for the improvements.
  2. I'm a little confused by the results in the table you shared; the N:M vector- and block-wise results are sometimes significantly better than the corresponding unconstrained vector- and block-wise results. I looked for the MLSys work you referenced to try to get some insight behind the changes and improvements, but it doesn't seem to be available, yet. Is it possible to share a pre-print with me by email, or did my search miss a public version?
  3. This same functionality, and more, is already available in the allowed_layer_names and disallowed_layer_names parameters. This existing interface is more predictable, too, since it doesn't rely on the order in which modules were added to the __sparse_parameters class member, nor how many non-sparse layers were present before the sparse layers begin. (For instance, ASP can decide not to sparsify layers based on channel counts, which could change which layers are later ignored by the layer_wise_ratio.) Is there something this approach can do that the existing interface cannot? If it's just convenience, I'm not sure it's worth the extra complexity, potentially unexpected behavior, and ongoing maintenance costs.
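
For illustration, keeping the first six BERT-large encoder layers dense via the existing interface might look like the sketch below (the module names are hypothetical, following the Hugging Face BERT layout; it assumes ASP matches full module names, so each Linear submodule is listed explicitly):

    # keep the first 6 encoder layers dense by excluding their Linear
    # submodules from sparsification (names depend on the actual model)
    dense_layers = [f"bert.encoder.layer.{i}.{sub}" for i in range(6)
                    for sub in ("attention.self.query", "attention.self.key",
                                "attention.self.value", "attention.output.dense",
                                "intermediate.dense", "output.dense")]

    ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2,
                               whitelist=[torch.nn.Linear, torch.nn.Conv2d],
                               disallowed_layer_names=dense_layers,
                               allow_recompute_mask=False)
    ASP.init_optimizer_for_pruning(optimizer)
    ASP.compute_sparse_masks()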

@LeiWang1999 (Author) commented:

Thanks @jpool-nv for the suggestions; these points do deserve more careful consideration, and I will push some code updates in the next few weeks as time permits. As for the paper, there is no publicly available version yet; I believe one will be available in a few weeks.
