From ff940d2ba83e93ebc0b6562742df8aabd9d1f9c2 Mon Sep 17 00:00:00 2001 From: Jeremy Felder Date: Mon, 8 Apr 2024 19:42:03 +0300 Subject: [PATCH] [FEAT]: golang device slice ranges (#463) ## Describe the changes This PR adds the capability to slice a DeviceSlice, allowing portions of data that are already on the device to be reused. Additionally, this PR removes the need for a HostSlice underlying type to implement a Size function and uses unsafe.Sizeof instead. This together with #407 will allow direct usage of gnark-crypto types with HostSlice without the need for converting to ICICLE types --------- Co-authored-by: nonam3e --- wrappers/golang/core/slice.go | 81 +++++++++++++++++++++++++----- wrappers/golang/core/slice_test.go | 32 ++++++++++++ 2 files changed, 100 insertions(+), 13 deletions(-) diff --git a/wrappers/golang/core/slice.go b/wrappers/golang/core/slice.go index 8ec1c1f92..56ee82064 100644 --- a/wrappers/golang/core/slice.go +++ b/wrappers/golang/core/slice.go @@ -54,6 +54,69 @@ func (d DeviceSlice) CheckDevice() { } } +func (d *DeviceSlice) Range(start, end int, endInclusive bool) DeviceSlice { + if end <= start { + panic("Cannot have negative or zero size slices") + } + + if end >= d.length { + panic("Cannot increase slice size from Range") + } + + var newSlice DeviceSlice + switch { + case start < 0: + panic("Negative value for start is not supported") + case start == 0: + newSlice = d.RangeTo(end, endInclusive) + case start > 0: + tempSlice := d.RangeFrom(start) + newSlice = tempSlice.RangeTo(end-start, endInclusive) + } + return newSlice +} + +func (d *DeviceSlice) RangeTo(end int, inclusive bool) DeviceSlice { + if end <= 0 { + panic("Cannot have negative or zero size slices") + } + + if end >= d.length { + panic("Cannot increase slice size from Range") + } + + var newSlice DeviceSlice + sizeOfElement := d.capacity / d.length + newSlice.length = end + if inclusive { + newSlice.length += 1 + } + newSlice.capacity = newSlice.length * sizeOfElement + newSlice.inner = d.inner + return newSlice +} + +func (d *DeviceSlice) RangeFrom(start int) DeviceSlice { + if start >= d.length { + panic("Cannot have negative or zero size slices") + } + + if start < 0 { + panic("Negative value for start is not supported") + } + + var newSlice DeviceSlice + sizeOfElement := d.capacity / d.length + + newSlice.inner = unsafe.Pointer(uintptr(d.inner) + uintptr(start)*uintptr(sizeOfElement)) + newSlice.length = d.length - start + newSlice.capacity = d.capacity - start*sizeOfElement + + return newSlice +} + +// TODO: change signature to be Malloc(element, numElements) +// calc size internally func (d *DeviceSlice) Malloc(size, sizeOfElement int) (DeviceSlice, cr.CudaError) { dp, err := cr.Malloc(uint(size)) d.inner = dp @@ -90,20 +153,13 @@ func (d *DeviceSlice) FreeAsync(stream cr.Stream) cr.CudaError { return err } -type HostSliceInterface interface { - Size() int -} - -type HostSlice[T HostSliceInterface] []T - -func HostSliceFromElements[T HostSliceInterface](elements []T) HostSlice[T] { - slice := make(HostSlice[T], len(elements)) - copy(slice, elements) +type HostSlice[T any] []T - return slice +func HostSliceFromElements[T any](elements []T) HostSlice[T] { + return elements } -func HostSliceWithValue[T HostSliceInterface](underlyingValue T, size int) HostSlice[T] { +func HostSliceWithValue[T any](underlyingValue T, size int) HostSlice[T] { slice := make(HostSlice[T], size) for i := range slice { slice[i] = underlyingValue @@ -129,7 +185,7 @@ func (h HostSlice[T]) IsOnDevice() bool { } func (h HostSlice[T]) SizeOfElement() int { - return h[0].Size() + return int(unsafe.Sizeof(h[0])) } func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *DeviceSlice { @@ -142,7 +198,6 @@ func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *Devic panic("Number of bytes to copy is too large for destination") } - // hostSrc := unsafe.Pointer(h.AsPointer()) hostSrc := unsafe.Pointer(&h[0]) cr.CopyToDevice(dst.inner, hostSrc, uint(size)) dst.length = h.Len() diff --git a/wrappers/golang/core/slice_test.go b/wrappers/golang/core/slice_test.go index c7c44cbcd..6f9709cc7 100644 --- a/wrappers/golang/core/slice_test.go +++ b/wrappers/golang/core/slice_test.go @@ -81,6 +81,10 @@ func TestHostSlice(t *testing.T) { hostSlice := HostSliceFromElements(randFields) assert.Equal(t, hostSlice.Len(), 4) assert.Equal(t, hostSlice.Cap(), 4) + + hostSliceCasted := (HostSlice[internal.MockField])(randFields) + assert.Equal(t, hostSliceCasted.Len(), 4) + assert.Equal(t, hostSliceCasted.Cap(), 4) } func TestHostSliceIsEmpty(t *testing.T) { @@ -190,3 +194,31 @@ func TestCopyToFromHostDeviceProjectivePoints(t *testing.T) { assert.Equal(t, hostSlice, hostSlice2) } + +func TestSliceRanges(t *testing.T) { + var deviceSlice DeviceSlice + + numPoints := 1 << 3 + randProjectives := randomProjectivePoints(numPoints, fieldSize) + hostSlice := (HostSlice[internal.MockProjective])(randProjectives) + hostSlice.CopyToDevice(&deviceSlice, true) + + // RangeFrom + var zeroProj internal.MockProjective + hostSliceRet := HostSliceWithValue[internal.MockProjective](zeroProj, numPoints-2) + + deviceSliceRangeFrom := deviceSlice.RangeFrom(2) + hostSliceRet.CopyFromDevice(&deviceSliceRangeFrom) + assert.Equal(t, hostSlice[2:], hostSliceRet) + + // RangeTo + deviceSliceRangeTo := deviceSlice.RangeTo(numPoints-3, true) + hostSliceRet.CopyFromDevice(&deviceSliceRangeTo) + assert.Equal(t, hostSlice[:6], hostSliceRet) + + // Range + hostSliceRange := HostSliceWithValue[internal.MockProjective](zeroProj, numPoints-4) + deviceSliceRange := deviceSlice.Range(2, numPoints-3, true) + hostSliceRange.CopyFromDevice(&deviceSliceRange) + assert.Equal(t, hostSlice[2:6], hostSliceRange) +}