From 7165a333c356963fbcc5e05648b9bb8b300af0e5 Mon Sep 17 00:00:00 2001 From: archblob Date: Fri, 21 Mar 2014 05:56:53 +0200 Subject: [PATCH] Add sger and dger. --- src/Numerical/HBLAS/BLAS.hs | 9 ++++++++ src/Numerical/HBLAS/BLAS/FFI.hs | 11 ++++++++++ src/Numerical/HBLAS/BLAS/Internal.hs | 33 ++++++++++++++++++++++++++++ tests/UnitBLAS/Level2.hs | 24 +++++++++++++++++++- 4 files changed, 76 insertions(+), 1 deletion(-) diff --git a/src/Numerical/HBLAS/BLAS.hs b/src/Numerical/HBLAS/BLAS.hs index 0649549..da1dcb6 100644 --- a/src/Numerical/HBLAS/BLAS.hs +++ b/src/Numerical/HBLAS/BLAS.hs @@ -68,6 +68,9 @@ module Numerical.HBLAS.BLAS( ,cgemv ,zgemv + ,sger + ,dger + ,strsv ,dtrsv ,ctrsv @@ -121,3 +124,9 @@ ctrsv = trsvAbstraction "ctrsv" cblas_ctrsv_safe cblas_ctrsv_unsafe ztrsv :: PrimMonad m => TrsvFun (Complex Double) orient (PrimState m) m ztrsv = trsvAbstraction "ztrsv" cblas_ztrsv_safe cblas_ztrsv_unsafe + +sger :: PrimMonad m => GerFun Float orient (PrimState m) m +sger = gerAbstraction "sger" cblas_sger_safe cblas_sger_unsafe + +dger :: PrimMonad m => GerFun Double orient (PrimState m) m +dger = gerAbstraction "dger" cblas_dger_safe cblas_dger_unsafe diff --git a/src/Numerical/HBLAS/BLAS/FFI.hs b/src/Numerical/HBLAS/BLAS/FFI.hs index b474900..539e713 100644 --- a/src/Numerical/HBLAS/BLAS/FFI.hs +++ b/src/Numerical/HBLAS/BLAS/FFI.hs @@ -270,6 +270,17 @@ foreign import ccall safe "cblas_zgemv" -- perform the rank 1 operation A := alpha*x*y' + A, +type GerFunFFI el = CBLAS_ORDERT -> CInt -> CInt -> el -> Ptr el -> CInt -> Ptr el -> CInt -> Ptr el -> CInt -> IO () + +foreign import ccall unsafe "cblas_sger" cblas_sger_unsafe :: + CBLAS_ORDERT -> CInt -> CInt -> Float -> Ptr Float -> CInt -> Ptr Float -> CInt -> Ptr Float -> CInt -> IO () +foreign import ccall safe "cblas_sger" cblas_sger_safe :: + CBLAS_ORDERT -> CInt -> CInt -> Float -> Ptr Float -> CInt -> Ptr Float -> CInt -> Ptr Float -> CInt -> IO () +foreign import ccall unsafe "cblas_dger" cblas_dger_unsafe 
:: + CBLAS_ORDERT -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Ptr Double -> CInt -> IO () +foreign import ccall safe "cblas_dger" cblas_dger_safe :: + CBLAS_ORDERT -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Ptr Double -> CInt -> IO () + --void cblas_sger ( enum CBLAS_ORDER order, CInt M, CInt N, Float alpha, Float *X, CInt incX, Float *Y, CInt incY, Float *A, CInt lda); --void cblas_dger ( enum CBLAS_ORDER order, CInt M, CInt N, Double alpha, Double *X, CInt incX, Double *Y, CInt incY, Double *A, CInt lda); --void cblas_cgeru( enum CBLAS_ORDER order, CInt M, CInt N, Float *alpha, Float *X, CInt incX, Float *Y, CInt incY, Float *A, CInt lda); diff --git a/src/Numerical/HBLAS/BLAS/Internal.hs b/src/Numerical/HBLAS/BLAS/Internal.hs index 75dc075..365a977 100644 --- a/src/Numerical/HBLAS/BLAS/Internal.hs +++ b/src/Numerical/HBLAS/BLAS/Internal.hs @@ -3,10 +3,12 @@ module Numerical.HBLAS.BLAS.Internal( GemmFun ,GemvFun + ,GerFun ,TrsvFun ,gemmAbstraction ,gemvAbstraction + ,gerAbstraction ,trsvAbstraction ) where @@ -24,6 +26,8 @@ type GemmFun el orient s m = Transpose ->Transpose -> el -> el -> MDenseMatrix type GemvFun el orient s m = Transpose -> el -> el -> MDenseMatrix s orient el -> MDenseVector s Direct el -> MDenseVector s Direct el -> m () +type GerFun el orient s m = + el -> MDenseVector s Direct el -> MDenseVector s Direct el -> MDenseMatrix s orient el -> m () type TrsvFun el orient s m = MatUpLo -> Transpose -> MatDiag @@ -182,6 +186,35 @@ gemvAbstraction gemvName gemvSafeFFI gemvUnsafeFFI constHandler = gemv (fromIntegral bstride) betaPtr cp (fromIntegral cstride) +{-# NOINLINE gerAbstraction #-} +gerAbstraction :: (SM.Storable el, PrimMonad m) + => String + -> GerFunFFI el + -> GerFunFFI el + -> forall orient . 
GerFun el orient (PrimState m) m +gerAbstraction gerName gerSafeFFI gerUnsafeFFI = ger + where + shouldCallFast :: Int -> Int -> Bool + shouldCallFast m n = flopsThreshold >= (fromIntegral m :: Int64) + * (fromIntegral n :: Int64) + + isBadGer :: Int -> Int -> Int -> Int -> Bool + isBadGer dx dy ax ay = ax < 0 || ay < 0 || dx < ax || dy < ay + + ger alpha (MutableDenseVector _ xdim xstride xbuff) + (MutableDenseVector _ ydim ystride ybuff) + (MutableDenseMatrix ornta ax ay astride abuff) + | isBadGer xdim ydim ax ay = + error $! "bad dimension args to GER: xdim ydim ax ay" ++ show [xdim, ydim, ax, ay] + | SM.overlaps xbuff abuff || SM.overlaps ybuff abuff = + error $! "The read and write inputs for: " ++ gerName ++ " overlap. This is a programmer error. Please fix." + | otherwise = + unsafeWithPrim xbuff $ \xp -> + unsafeWithPrim ybuff $ \yp -> + unsafeWithPrim abuff $ \ap -> + unsafePrimToPrim $! (if shouldCallFast ax ay then gerUnsafeFFI else gerSafeFFI) + (encodeNiceOrder ornta) (fromIntegral ax) (fromIntegral ay) alpha xp + (fromIntegral xstride) yp (fromIntegral ystride) ap (fromIntegral astride) {-# NOINLINE trsvAbstraction #-} trsvAbstraction :: (SM.Storable el, PrimMonad m) diff --git a/tests/UnitBLAS/Level2.hs b/tests/UnitBLAS/Level2.hs index 2ffa995..f0e22c5 100644 --- a/tests/UnitBLAS/Level2.hs +++ b/tests/UnitBLAS/Level2.hs @@ -66,6 +66,26 @@ matmatTest1ZGEMV = do resList <- Matrix.mutableVectorToList $ _bufferMutDenseVector res resList @?= [2.0,2.0] +---- +---- + +matmatTest1SGER :: IO () +matmatTest1SGER = do + res <- Matrix.generateMutableDenseMatrix (Matrix.SRow) (2,2) (\_ -> 1.0) + x <- Matrix.generateMutableDenseVector 2 (\_ -> 2.0) + y <- Matrix.generateMutableDenseVector 2 (\_ -> 3.0) + BLAS.sger 2.0 x y res + resList <- Matrix.mutableVectorToList $ _bufferDenMutMat res + resList @?= [13.0,13.0,13.0,13.0] + +matmatTest1DGER :: IO () +matmatTest1DGER = do + res <- Matrix.generateMutableDenseMatrix (Matrix.SRow) (2,2) (\_ -> 1.0) + x <- 
Matrix.generateMutableDenseVector 2 (\_ -> 2.0) + y <- Matrix.generateMutableDenseVector 2 (\_ -> 3.0) + BLAS.dger 2.0 x y res + resList <- Matrix.mutableVectorToList $ _bufferDenMutMat res + resList @?= [13.0,13.0,13.0,13.0] ---- ---- @@ -120,6 +140,8 @@ unitTestLevel2BLAS = testGroup "BLAS Level 2 tests " [ ,testCase "dtrsv on 2x2 upper 1s" matmatTest1DTRSV ,testCase "ctrsv on 2x2 upper 1s" matmatTest1CTRSV ,testCase "ztrsv on 2x2 upper 1s" matmatTest1ZTRSV - +---- ger tests + ,testCase "sger on 2x2 all 1s" matmatTest1SGER + ,testCase "dger on 2x2 all 1s" matmatTest1DGER ]