Skip to content

Commit

Permalink
[FEAT]: golang device slice ranges (#463)
Browse files Browse the repository at this point in the history
## Describe the changes

This PR adds the capability to slice a DeviceSlice, allowing portions of
data that are already on the device to be reused.

Additionally, this PR removes the need for a HostSlice underlying type
to implement a Size function and uses unsafe.Sizeof instead. This
together with #407 will allow direct usage of gnark-crypto types with
HostSlice without the need for converting to ICICLE types

---------

Co-authored-by: nonam3e <[email protected]>
  • Loading branch information
2 people authored and yshekel committed May 19, 2024
1 parent 04b97d4 commit ff940d2
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 13 deletions.
81 changes: 68 additions & 13 deletions wrappers/golang/core/slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,69 @@ func (d DeviceSlice) CheckDevice() {
}
}

func (d *DeviceSlice) Range(start, end int, endInclusive bool) DeviceSlice {
if end <= start {
panic("Cannot have negative or zero size slices")
}

if end >= d.length {
panic("Cannot increase slice size from Range")
}

var newSlice DeviceSlice
switch {
case start < 0:
panic("Negative value for start is not supported")
case start == 0:
newSlice = d.RangeTo(end, endInclusive)
case start > 0:
tempSlice := d.RangeFrom(start)
newSlice = tempSlice.RangeTo(end-start, endInclusive)
}
return newSlice
}

func (d *DeviceSlice) RangeTo(end int, inclusive bool) DeviceSlice {
if end <= 0 {
panic("Cannot have negative or zero size slices")
}

if end >= d.length {
panic("Cannot increase slice size from Range")
}

var newSlice DeviceSlice
sizeOfElement := d.capacity / d.length
newSlice.length = end
if inclusive {
newSlice.length += 1
}
newSlice.capacity = newSlice.length * sizeOfElement
newSlice.inner = d.inner
return newSlice
}

func (d *DeviceSlice) RangeFrom(start int) DeviceSlice {
if start >= d.length {
panic("Cannot have negative or zero size slices")
}

if start < 0 {
panic("Negative value for start is not supported")
}

var newSlice DeviceSlice
sizeOfElement := d.capacity / d.length

newSlice.inner = unsafe.Pointer(uintptr(d.inner) + uintptr(start)*uintptr(sizeOfElement))
newSlice.length = d.length - start
newSlice.capacity = d.capacity - start*sizeOfElement

return newSlice
}

// TODO: change signature to be Malloc(element, numElements)
// calc size internally
func (d *DeviceSlice) Malloc(size, sizeOfElement int) (DeviceSlice, cr.CudaError) {
dp, err := cr.Malloc(uint(size))
d.inner = dp
Expand Down Expand Up @@ -90,20 +153,13 @@ func (d *DeviceSlice) FreeAsync(stream cr.Stream) cr.CudaError {
return err
}

type HostSliceInterface interface {
Size() int
}

type HostSlice[T HostSliceInterface] []T

func HostSliceFromElements[T HostSliceInterface](elements []T) HostSlice[T] {
slice := make(HostSlice[T], len(elements))
copy(slice, elements)
type HostSlice[T any] []T

return slice
func HostSliceFromElements[T any](elements []T) HostSlice[T] {
return elements
}

func HostSliceWithValue[T HostSliceInterface](underlyingValue T, size int) HostSlice[T] {
func HostSliceWithValue[T any](underlyingValue T, size int) HostSlice[T] {
slice := make(HostSlice[T], size)
for i := range slice {
slice[i] = underlyingValue
Expand All @@ -129,7 +185,7 @@ func (h HostSlice[T]) IsOnDevice() bool {
}

func (h HostSlice[T]) SizeOfElement() int {
return h[0].Size()
return int(unsafe.Sizeof(h[0]))
}

func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *DeviceSlice {
Expand All @@ -142,7 +198,6 @@ func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *Devic
panic("Number of bytes to copy is too large for destination")
}

// hostSrc := unsafe.Pointer(h.AsPointer())
hostSrc := unsafe.Pointer(&h[0])
cr.CopyToDevice(dst.inner, hostSrc, uint(size))
dst.length = h.Len()
Expand Down
32 changes: 32 additions & 0 deletions wrappers/golang/core/slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ func TestHostSlice(t *testing.T) {
hostSlice := HostSliceFromElements(randFields)
assert.Equal(t, hostSlice.Len(), 4)
assert.Equal(t, hostSlice.Cap(), 4)

hostSliceCasted := (HostSlice[internal.MockField])(randFields)
assert.Equal(t, hostSliceCasted.Len(), 4)
assert.Equal(t, hostSliceCasted.Cap(), 4)
}

func TestHostSliceIsEmpty(t *testing.T) {
Expand Down Expand Up @@ -190,3 +194,31 @@ func TestCopyToFromHostDeviceProjectivePoints(t *testing.T) {

assert.Equal(t, hostSlice, hostSlice2)
}

func TestSliceRanges(t *testing.T) {
var deviceSlice DeviceSlice

numPoints := 1 << 3
randProjectives := randomProjectivePoints(numPoints, fieldSize)
hostSlice := (HostSlice[internal.MockProjective])(randProjectives)
hostSlice.CopyToDevice(&deviceSlice, true)

// RangeFrom
var zeroProj internal.MockProjective
hostSliceRet := HostSliceWithValue[internal.MockProjective](zeroProj, numPoints-2)

deviceSliceRangeFrom := deviceSlice.RangeFrom(2)
hostSliceRet.CopyFromDevice(&deviceSliceRangeFrom)
assert.Equal(t, hostSlice[2:], hostSliceRet)

// RangeTo
deviceSliceRangeTo := deviceSlice.RangeTo(numPoints-3, true)
hostSliceRet.CopyFromDevice(&deviceSliceRangeTo)
assert.Equal(t, hostSlice[:6], hostSliceRet)

// Range
hostSliceRange := HostSliceWithValue[internal.MockProjective](zeroProj, numPoints-4)
deviceSliceRange := deviceSlice.Range(2, numPoints-3, true)
hostSliceRange.CopyFromDevice(&deviceSliceRange)
assert.Equal(t, hostSlice[2:6], hostSliceRange)
}

0 comments on commit ff940d2

Please sign in to comment.