Add sger and dger. #18

Merged
merged 1 commit into from
Feb 2, 2015
9 changes: 9 additions & 0 deletions src/Numerical/HBLAS/BLAS.hs
@@ -68,6 +68,9 @@ module Numerical.HBLAS.BLAS(
,cgemv
,zgemv

,sger
,dger

,strsv
,dtrsv
,ctrsv
@@ -121,3 +124,9 @@ ctrsv = trsvAbstraction "ctrsv" cblas_ctrsv_safe cblas_ctrsv_unsafe

ztrsv :: PrimMonad m => TrsvFun (Complex Double) orient (PrimState m) m
ztrsv = trsvAbstraction "ztrsv" cblas_ztrsv_safe cblas_ztrsv_unsafe

sger :: PrimMonad m => GerFun Float orient (PrimState m) m
sger = gerAbstraction "sger" cblas_sger_safe cblas_sger_unsafe

dger :: PrimMonad m => GerFun Double orient (PrimState m) m
dger = gerAbstraction "dger" cblas_dger_safe cblas_dger_unsafe
11 changes: 11 additions & 0 deletions src/Numerical/HBLAS/BLAS/FFI.hs
@@ -270,6 +270,17 @@ foreign import ccall safe "cblas_zgemv"

-- perform the rank 1 operation A := alpha*x*y' + A

type GerFunFFI el = CBLAS_ORDERT -> CInt -> CInt -> el -> Ptr el -> CInt -> Ptr el -> CInt -> Ptr el -> CInt -> IO ()

foreign import ccall unsafe "cblas_sger" cblas_sger_unsafe ::
CBLAS_ORDERT -> CInt -> CInt -> Float -> Ptr Float -> CInt -> Ptr Float -> CInt -> Ptr Float -> CInt -> IO ()
foreign import ccall safe "cblas_sger" cblas_sger_safe ::
CBLAS_ORDERT -> CInt -> CInt -> Float -> Ptr Float -> CInt -> Ptr Float -> CInt -> Ptr Float -> CInt -> IO ()
foreign import ccall unsafe "cblas_dger" cblas_dger_unsafe ::
CBLAS_ORDERT -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Ptr Double -> CInt -> IO ()
foreign import ccall safe "cblas_dger" cblas_dger_safe ::
CBLAS_ORDERT -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Ptr Double -> CInt -> IO ()

--void cblas_sger ( enum CBLAS_ORDER order, CInt M, CInt N, Float alpha, Float *X, CInt incX, Float *Y, CInt incY, Float *A, CInt lda);
--void cblas_dger ( enum CBLAS_ORDER order, CInt M, CInt N, Double alpha, Double *X, CInt incX, Double *Y, CInt incY, Double *A, CInt lda);
--void cblas_cgeru( enum CBLAS_ORDER order, CInt M, CInt N, Float *alpha, Float *X, CInt incX, Float *Y, CInt incY, Float *A, CInt lda);
33 changes: 33 additions & 0 deletions src/Numerical/HBLAS/BLAS/Internal.hs
@@ -3,10 +3,12 @@
module Numerical.HBLAS.BLAS.Internal(
GemmFun
,GemvFun
,GerFun
,TrsvFun

,gemmAbstraction
,gemvAbstraction
,gerAbstraction
,trsvAbstraction
) where

@@ -24,6 +26,8 @@ type GemmFun el orient s m = Transpose ->Transpose -> el -> el -> MDenseMatrix
type GemvFun el orient s m = Transpose -> el -> el
-> MDenseMatrix s orient el -> MDenseVector s Direct el -> MDenseVector s Direct el -> m ()

type GerFun el orient s m =
el -> MDenseVector s Direct el -> MDenseVector s Direct el -> MDenseMatrix s orient el -> m ()

type TrsvFun el orient s m =
MatUpLo -> Transpose -> MatDiag
@@ -182,6 +186,35 @@ gemvAbstraction gemvName gemvSafeFFI gemvUnsafeFFI constHandler = gemv
(fromIntegral bstride) betaPtr cp (fromIntegral cstride)


{-# NOINLINE gerAbstraction #-}
gerAbstraction :: (SM.Storable el, PrimMonad m)
=> String
-> GerFunFFI el
-> GerFunFFI el
-> forall orient . GerFun el orient (PrimState m) m
gerAbstraction gerName gerSafeFFI gerUnsafeFFI = ger
where
shouldCallFast :: Int -> Int -> Bool
shouldCallFast m n = flopsThreshold >= (fromIntegral m :: Int64)
Member


@archblob derp, should this have been a <= instead? I.e., should we be calling the fast code IF m*n is LESS than flopsThreshold? Shows what I get for merging a PR before doing a close reading.

Contributor Author


I haven't really looked at the code since the PR was done almost a year ago :-P. I just rebased and added some tests. Will take a look when I get home.

Contributor


Yes, I think we should be calling the fast code IF m*n is LESS (or LEQ) than flopsThreshold. But I think that is the same as saying we should be calling the fast code IF flopsThreshold is GREATER (or GEQ) than m*n?

Member


You're correct. It is correct as written. Fast is true when the product is less than the threshold. Probably I should write a little helper function to make it less confusing.

On Thursday, May 7, 2015, Jueji Yang wrote:

> Yes, I think we should be calling the fast code IF m*n is LESS (or LEQ) than flopsThreshold. But I think it is the same that we should be calling the fast code IF flopsThreshold is GREATER (or GEQ) than m*n?
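The helper mentioned in the thread above could look roughly like the following. This is only a sketch: the name `isSmallProblem` is hypothetical, and the threshold value is an assumed placeholder, not necessarily what hblas uses. It puts the problem size on the left of the comparison so the condition reads the way the discussion phrases it, while staying logically identical to `flopsThreshold >= m * n`.

```haskell
import Data.Int (Int64)

-- Assumed value, for illustration only; hblas defines its own flopsThreshold.
flopsThreshold :: Int64
flopsThreshold = 10000

-- Hypothetical helper: True when an m-by-n problem is small enough that the
-- unsafe (fast) FFI call is preferred. Equivalent to
-- flopsThreshold >= m * n, but reads as "size is at most the threshold".
isSmallProblem :: Int -> Int -> Bool
isSmallProblem m n = problemSize <= flopsThreshold
  where
    -- Widen to Int64 before multiplying so the product is computed in Int64.
    problemSize = fromIntegral m * fromIntegral n :: Int64
```

With such a helper, `shouldCallFast` in `gerAbstraction` could simply be defined as `isSmallProblem`, so the direction of the comparison is stated once.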

* (fromIntegral n :: Int64)

isBadGer :: Int -> Int -> Int -> Int -> Bool
isBadGer dx dy ax ay = ax < 0 || ay < 0 || dx < ax || dy < ay

ger alpha (MutableDenseVector _ xdim xstride xbuff)
(MutableDenseVector _ ydim ystride ybuff)
(MutableDenseMatrix ornta ax ay astride abuff)
| isBadGer xdim ydim ax ay =
error $! "bad dimension args to GER: xdim ydim ax ay" ++ show [xdim, ydim, ax, ay]
| SM.overlaps xbuff abuff || SM.overlaps ybuff abuff =
error $! "The read and write inputs for: " ++ gerName ++ " overlap. This is a programmer error. Please fix."
| otherwise =
unsafeWithPrim xbuff $ \xp ->
unsafeWithPrim ybuff $ \yp ->
unsafeWithPrim abuff $ \ap ->
unsafePrimToPrim $! (if shouldCallFast ax ay then gerUnsafeFFI else gerSafeFFI)
(encodeNiceOrder ornta) (fromIntegral ax) (fromIntegral ay) alpha xp
(fromIntegral xstride) yp (fromIntegral ystride) ap (fromIntegral astride)

{-# NOINLINE trsvAbstraction #-}
trsvAbstraction :: (SM.Storable el, PrimMonad m)
24 changes: 23 additions & 1 deletion tests/UnitBLAS/Level2.hs
@@ -66,6 +66,26 @@ matmatTest1ZGEMV = do
resList <- Matrix.mutableVectorToList $ _bufferMutDenseVector res
resList @?= [2.0,2.0]

----
----

matmatTest1SGER :: IO ()
matmatTest1SGER = do
res <- Matrix.generateMutableDenseMatrix (Matrix.SRow) (2,2) (\_ -> 1.0)
x <- Matrix.generateMutableDenseVector 2 (\_ -> 2.0)
y <- Matrix.generateMutableDenseVector 2 (\_ -> 3.0)
BLAS.sger 2.0 x y res
resList <- Matrix.mutableVectorToList $ _bufferDenMutMat res
Member


It might be good to add a wee comment explaining A := alpha*x*y' + A, with the numbers filled in, so that why this test is correct is more immediately obvious.

Which does raise the related point that we should probably put those formulae in the haddocks somewhere, but probably not in the tests or this PR.

Member


e.g. a = 2 * (2*3) + 1 = 13
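To make that arithmetic checkable without linking against BLAS, here is a small pure-list reference model of the rank-1 update A := alpha*x*y' + A. The name `gerRef` and the list-of-rows representation are assumptions for illustration only; they are not part of hblas.

```haskell
-- Hypothetical reference model of GER on plain lists (not the hblas API).
-- The matrix is a list of rows; entry (i, j) becomes alpha * x_i * y_j + a_ij.
gerRef :: Num a => a -> [a] -> [a] -> [[a]] -> [[a]]
gerRef alpha xs ys rows =
  [ [ alpha * x * y + aij | (y, aij) <- zip ys row ]
  | (x, row) <- zip xs rows
  ]

-- The inputs used by the test: alpha = 2, x = [2,2], y = [3,3], A all ones.
-- Every entry is 2 * (2 * 3) + 1 = 13.
expectedSGER :: [Double]
expectedSGER = concat (gerRef 2 [2, 2] [3, 3] [[1, 1], [1, 1]])
```

Flattening the result row by row gives the same four-element list the test compares against.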

resList @?= [13.0,13.0,13.0,13.0]

matmatTest1DGER :: IO ()
matmatTest1DGER = do
res <- Matrix.generateMutableDenseMatrix (Matrix.SRow) (2,2) (\_ -> 1.0)
x <- Matrix.generateMutableDenseVector 2 (\_ -> 2.0)
y <- Matrix.generateMutableDenseVector 2 (\_ -> 3.0)
BLAS.dger 2.0 x y res
resList <- Matrix.mutableVectorToList $ _bufferDenMutMat res
resList @?= [13.0,13.0,13.0,13.0]

----
----
@@ -120,6 +140,8 @@ unitTestLevel2BLAS = testGroup "BLAS Level 2 tests " [
,testCase "dtrsv on 2x2 upper 1s" matmatTest1DTRSV
,testCase "ctrsv on 2x2 upper 1s" matmatTest1CTRSV
,testCase "ztrsv on 2x2 upper 1s" matmatTest1ZTRSV

---- ger tests
,testCase "sger on 2x2 all 1s" matmatTest1SGER
,testCase "dger on 2x2 all 1s" matmatTest1DGER
]