Skip to content

Commit

Permalink
clean up doc string
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Aug 7, 2024
1 parent 757b03d commit 21f1235
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions attn_gym/mods/softcapping.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Implementation of an tanh softcapping score mod popularized in Gemma2 paper."""
"""Implementation of tanh softcapping score mod popularized in Gemma2 and Grok-1"""

import torch
from torch import Tensor
Expand All @@ -11,21 +11,21 @@


@torch.library.custom_op("approx::tanh", mutates_args=())
def tanh_approx(inp: Tensor) -> Tensor:
def _tanh_approx(inp: Tensor) -> Tensor:
return torch.tanh(inp)


@tanh_approx.register_fake
@_tanh_approx.register_fake
def _(inp: torch.Tensor) -> torch.Tensor:
return torch.tanh(inp)


def tanh_approx_lowering(inp):
def _tanh_approx_lowering(inp):
fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
return make_pointwise(fn)(inp)


register_lowering(torch.ops.approx.tanh)(tanh_approx_lowering)
register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering)


class _TanhApprox(torch.autograd.Function):
Expand Down

0 comments on commit 21f1235

Please sign in to comment.