{- |
If you have a Traversable instance of a record,
you can load and store all elements
that are accessible by Traversable methods.
In this attempt we support elements of unequal size.
However, this can be awfully slow,
since the program might perform size computations at run-time.
-}
module Foreign.Storable.TraversableUnequalSizes (
   alignment, sizeOf,
   peek, poke,
   ) where

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold

import Control.Monad.Trans.State
          (StateT, evalStateT, gets, modify, )
import Control.Monad.IO.Class (liftIO, )

import Foreign.Storable.FixedArray (roundUp, )
import qualified Foreign.Storable as St

import Foreign.Ptr (Ptr, )
import Foreign.Storable (Storable, )


{-# INLINE alignment #-}
-- | Alignment of the whole container: the least common multiple of the
-- alignments of all contained elements, or @1@ for an empty container.
alignment ::
   (Fold.Foldable f, Storable a) =>
   f a -> Int
alignment =
   Fold.foldl' combine 1
   where
      -- fold each element's alignment into the accumulated lcm
      combine acc x = lcm acc (St.alignment x)

{-# INLINE sizeOf #-}
-- | Size in bytes of the container's memory layout.
--
-- Elements are placed one after another in 'Fold.Foldable' order. Each one
-- starts at the running offset rounded up (via 'roundUp') to its own
-- alignment and advances the offset by its size. The final offset is
-- rounded up to the container's 'alignment'.
sizeOf ::
   (Fold.Foldable f, Storable a) =>
   f a -> Int
sizeOf f =
   let place offset x = roundUp (St.alignment x) offset + St.sizeOf x
       end = Fold.foldl' place 0 f
   in  roundUp (alignment f) end

{-
This function requires that alignment does not depend on an element value,
because we cannot know the value before loading it.
Thus @alignment (undefined::a)@ must be defined.
-}
{-# INLINE peek #-}
-- | Load a container from memory.
--
-- The shape of the result is taken from @skeleton@; its @()@ slots are
-- never inspected. Elements are read in traversal order, with the running
-- byte offset into @ptr@ starting at 0.
peek ::
   (Trav.Traversable f, Storable a) =>
   f () -> Ptr (f a) -> IO (f a)
peek skeleton ptr =
   evalStateT loadAll 0
   where
      -- replace every slot of the skeleton by one loaded element
      loadAll = Trav.mapM (\_ -> peekState ptr) skeleton

{-# INLINE peekState #-}
-- | Load one element from the offset returned by 'getOffset', then move
-- the running offset forward by the loaded element's size.
--
-- The alignment must be known before the element is loaded. It is
-- therefore queried on a placeholder that raises an error when evaluated,
-- so the element type's 'St.alignment' must not inspect its argument.
peekState ::
   (Storable a) =>
   Ptr (f a) -> StateT Int IO a
peekState p =
   do offset <- getOffset (placeholder p)
      value <- liftIO (St.peekByteOff p offset)
      value <$ advanceOffset value
   where
      -- only used to fix the element type for St.alignment
      placeholder :: Ptr (g b) -> b
      placeholder _ =
         error "Traversable.peek: alignment must not depend on the element value"

{-# INLINE poke #-}
-- | Store all elements of a container in 'Fold.Foldable' order, with the
-- running byte offset into @ptr@ starting at 0.
poke ::
   (Fold.Foldable f, Storable a) =>
   Ptr (f a) -> f a -> IO ()
poke ptr x =
   flip evalStateT 0 $
   Fold.for_ x $ \element -> pokeState ptr element

{-# INLINE pokeState #-}
-- | Store one element at the offset returned by 'getOffset', then move the
-- running offset forward by the element's size.
pokeState ::
   (Storable a) =>
   Ptr (f a) -> a -> StateT Int IO ()
pokeState p x =
   getOffset x >>= \offset ->
   liftIO (St.pokeByteOff p offset x) >>
   advanceOffset x

{-# INLINE getOffset #-}
-- | Align the running byte offset (the state) for the given element and
-- return that aligned offset as the element's position.
--
-- The aligned offset is written back into the state. A subsequent
-- 'advanceOffset' then adds the element size to the position the element
-- actually occupies, which is the same layout rule 'sizeOf' uses
-- (aligned offset plus size).
--
-- If only the unaligned offset were kept, the next element could be
-- placed inside this one. That happens whenever an element's size is not
-- a multiple of its alignment, or when sizes differ between elements.
--
-- The element is only passed to 'St.alignment', so a placeholder that
-- must not be forced can be given here.
getOffset ::
   (Storable a) =>
   a -> StateT Int IO Int
getOffset a = do
   -- presumably roundUp m x rounds x up to a multiple of m
   modify (roundUp (St.alignment a))
   gets id

{-# INLINE advanceOffset #-}
-- | Move the running byte offset (the state) forward by the size of the
-- given element.
advanceOffset ::
   (Storable a) =>
   a -> StateT Int IO ()
advanceOffset x =
   modify (\offset -> offset + St.sizeOf x)