[Inquiry] Document Masking and Assigning Different Weights #88
Yeah, this would likely need both a score mod and a document mask. Let's say that you have static scores and have already generated your … You can create another … that can look like:

```python
def score_mod(score, b, h, q_idx, kv_idx):
    # Scale each score by the static weight stored for the query position
    return score * doc_score[q_idx]
```

and let the mask mod handle masking out the irrelevant scores.
Thank you so much for your prompt and helpful reply! Over the past few days I have been trying to implement your suggestion and have made some progress, but I've run into a few issues that I'm hoping you could help clarify. Currently, the function I wrote to generate … However, my main challenge arises during decoding in an LLM, where tokens are generated sequentially, one at a time. This causes … Could you provide further guidance on how to handle this dynamic change in …? I would greatly appreciate any insights or suggestions you might have. Thank you once again for your time and support!
Hey, here is some example code showing how you can grow your lookup as your sequence length increases during decoding:

```python
from functools import partial

import torch
from torch.nn.attention.flex_attention import flex_attention

lookup = torch.randn(20, device="cuda")


def score_mod(score, b, h, q, k):
    return score * lookup[q]


make_tensor = partial(torch.rand, device="cuda", dtype=torch.float32)

# Without dynamic=True, this will recompile whenever the sequence length changes
flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)

q, k, v = (
    make_tensor(1, 1, 20, 16),
    make_tensor(1, 1, 20, 16),
    make_tensor(1, 1, 20, 16),
)
out = flex_compiled(q, k, v, score_mod=score_mod)
print(out.shape)

# Grow the lookup by one entry for the newly generated token
lookup = torch.cat((lookup, torch.randn(1, device="cuda")), dim=-1)
q, k, v = (
    make_tensor(1, 1, 21, 16),
    make_tensor(1, 1, 21, 16),
    make_tensor(1, 1, 21, 16),
)
out = flex_compiled(q, k, v, score_mod=score_mod)
print(out.shape)
```

Output:

```
torch.Size([1, 1, 20, 16])
torch.Size([1, 1, 21, 16])
```
Dear Developer, when I run your code, I get the following error:
Traceback (most recent call last):
File "/home/user/modded-nanogpt/attention-gym/attn_gym/masks/toy.py", line 26, in <module>
out = flex_compiled(q, k, v, score_mod=score_mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 573, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1379, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2861, in run
super().run()
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1053, in run
while self.step():
^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 963, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3041, in RETURN_VALUE
self._return(inst)
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3026, in _return
self.output.compile_subgraph(
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1087, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1361, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1411, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1462, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1441, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/__init__.py", line 2308, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1811, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 73, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1102, in aot_module_simplified
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1078, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 527, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 628, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
out = PropagateUnbackedSymInts(mod).run(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/fx/interpreter.py", line 167, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6668, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/fx/interpreter.py", line 228, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/fx/interpreter.py", line 308, in call_function
return target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_higher_order_ops/flex_attention.py", line 90, in __call__
return super().__call__(
^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_ops.py", line 440, in __call__
return wrapper()
^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 744, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_ops.py", line 436, in wrapper
return self.dispatch(
^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_ops.py", line 419, in dispatch
return kernel(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_higher_order_ops/flex_attention.py", line 710, in flex_attention_autograd
input_requires_grad = any(
^^^^
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/_higher_order_ops/flex_attention.py", line 711, in <genexpr>
t.requires_grad for t in (query, key, value, *score_mod_other_buffers)
^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'SymInt' object has no attribute 'requires_grad'
While executing %flex_attention : [num_users=1] = call_function[target=torch.ops.higher_order.flex_attention](args = (%l_query_, %l_key_, %l_value_, %score_mod_0, (%child, %child_1, None, None, %q_num_blocks, %q_indices, None, None, 1073741824, 1073741824, %mask_fn_0), %truediv, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, OUTPUT_LOGSUMEXP: True}, (%s0, %g_import_main_lookup), ()), kwargs = {})
Original traceback:
File "/home/user/micromamba/envs/torchdev/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 1286, in flex_attention
out, lse = flex_attention_hop(
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
I am not sure what the problem is. Here is my torch version:
Hmm, is that the exact version of my code? I recently landed some fixes to how we handle dynamic shapes. Can you try a more recent nightly?
Yes, I am using the exact version of your code. I will try a newer version now.
With the updated torch, the provided code works fine now :-) Thank you so much for your prompt help. I will try to write my own version of the code tomorrow. Thank you again!
Dear Developers,
Thank you for creating FlexAttention. I believe it is excellent work and fits well with my current research. Recently, I have been experimenting with this module and have encountered some issues. Please forgive me, as I am new to this field.
My question is about document masking. I have been studying `attention-gym/attn_gym/masks/document_mask.py`, but I would like to ask how to assign different weights to different documents. This might involve both the mask and the score mod, which I am not very familiar with. For example, how can I multiply all values within q: [3, 7], kv: [3, 7] by 0.5, or within q: [10, 15], kv: [10, 15] by 0.7?
I would greatly appreciate it if you could spare some time to clarify this for me. It would be incredibly helpful for my work. Thank you very much!
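The ranges in the question can also be encoded directly in a score mod, without a lookup table. This is a minimal sketch assuming the ranges are inclusive and that "values" means the pre-softmax attention scores; `SPANS` and `range_weight_score_mod` are hypothetical names. It is evaluated here on broadcast index grids rather than through `flex_attention`, so the weighting is easy to inspect.

```python
import torch

# Hypothetical (start, end, weight) spans from the question; both ends inclusive.
SPANS = [(3, 7, 0.5), (10, 15, 0.7)]


def range_weight_score_mod(score, b, h, q_idx, kv_idx):
    # Pairs outside every span keep their original score.
    out = score
    for start, end, weight in SPANS:
        # True only where both the query and the key position fall inside the span.
        in_span = (q_idx >= start) & (q_idx <= end) & (kv_idx >= start) & (kv_idx <= end)
        out = torch.where(in_span, score * weight, out)
    return out


# Inspect the weighting on a 20 x 20 grid of unit scores via broadcasting.
L = 20
q_idx = torch.arange(L).view(L, 1)
kv_idx = torch.arange(L).view(1, L)
weights = range_weight_score_mod(torch.ones(L, L), 0, 0, q_idx, kv_idx)
print(weights[5, 5].item(), weights[12, 14].item(), weights[5, 12].item())
```

The same function can be passed as `score_mod=range_weight_score_mod` to `flex_attention`. Combined with a document mask, only same-document pairs survive, so these spans line up with documents occupying positions 3 to 7 and 10 to 15.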