{-# LANGUAGE ExistentialQuantification, FlexibleContexts, TypeOperators #-}

module Test.IOSpec.STM
   (
   -- * The specification of STM
     STMS
   -- * Atomically
   , atomically
   -- * The STM monad
   , STM
   , TVar
   , newTVar
   , readTVar
   , writeTVar
   , retry
   , orElse
   , check
   )
   where

import Test.IOSpec.VirtualMachine
import Test.IOSpec.Types
import Data.Dynamic
import Data.Maybe (fromJust)
import Control.Monad.State
import Control.Monad (ap)

-- The 'STMS' data type and its instances.
--
-- | An expression of type @IOSpec 'STMS' a@ corresponds to an 'IO'
-- computation that may use 'atomically' and returns a value of type
-- @a@.
--
-- 'STMS' is rarely used on its own; you will normally combine it with
-- thread forking, e.g. @IOSpec (ForkS :+: STMS)@.
data STMS a =
  -- A transaction with some hidden result type @r@, paired with the
  -- continuation that consumes that result.
  forall r . Atomically (STM r) (r -> a)

instance Functor STMS where
  -- Mapping only post-composes onto the continuation; the wrapped
  -- transaction itself is left untouched.
  fmap g (Atomically transaction k) = Atomically transaction (\r -> g (k r))

-- | The 'atomically' function atomically executes an 'STM' action.
--
-- The transaction is injected as a single 'STMS' request whose
-- continuation simply returns the transaction's result.
atomically :: (STMS :<: f) => STM a -> IOSpec f a
atomically transaction = inject (Atomically transaction return)

instance Executable STMS where
  -- Run the whole transaction against a snapshot of the store. The
  -- resulting store is committed only when the transaction finishes;
  -- a retry leaves the store untouched and reports 'Block'.
  step (Atomically transaction k) = do
    snapshot <- get
    case runStateT (executeSTM transaction) snapshot of
      Done (Just result, committed) -> do
        put committed
        return (Step (k result))
      Done (Nothing, _) -> return Block
      -- Transactions never perform IO, so any other effect means the
      -- interpreter itself is broken.
      _ -> internalError "Unsafe usage of STM"

-- The 'STM' data type and its instances.
--
-- A transaction is represented as a tree of primitive steps that
-- 'executeSTM' interprets. Apart from 'STMReturn' and 'Retry', every
-- constructor carries the continuation to run after its primitive step,
-- so '>>=' only ever extends continuations.
data STM a =
    STMReturn a
  | NewTVar Data (Loc -> STM a)
  | ReadTVar Loc (Data -> STM a)
  | WriteTVar Loc Data (STM a)
  | Retry
    -- The two alternatives and the continuation are stored separately so
    -- that a 'retry' raised by the continuation (after the left
    -- alternative has already succeeded) does not fall back to the right
    -- alternative.
  | forall b . OrElse (STM b) (STM b) (b -> STM a)

instance Functor STM where
  fmap f (STMReturn x)     = STMReturn (f x)
  fmap f (NewTVar d k)     = NewTVar d (fmap f . k)
  fmap f (ReadTVar l k)    = ReadTVar l (fmap f . k)
  fmap f (WriteTVar l d k) = WriteTVar l d (fmap f k)
  fmap _ Retry             = Retry
  fmap f (OrElse p q k)    = OrElse p q (fmap f . k)

instance Applicative STM where
  pure  = STMReturn
  (<*>) = ap

-- 'return' is left to its default ('pure'); defining it separately is a
-- non-canonical Monad instance.
instance Monad STM where
  STMReturn x     >>= f = f x
  NewTVar d k     >>= f = NewTVar d (\l -> k l >>= f)
  ReadTVar l k    >>= f = ReadTVar l (\d -> k d >>= f)
  WriteTVar l d k >>= f = WriteTVar l d (k >>= f)
  Retry           >>= _ = Retry
  -- Only the continuation is extended; the alternatives are left alone,
  -- so 'orElse' catches a 'retry' from its left argument and nothing else.
  OrElse p q k    >>= f = OrElse p q (\x -> k x >>= f)

-- | A 'TVar' is a shared, mutable variable used by STM.
newtype TVar a = TVar Loc

-- | The 'newTVar' function creates a new transactional variable.
newTVar   :: Typeable a => a -> STM (TVar a)
newTVar d = NewTVar (toDyn d) (STMReturn . TVar)

-- | The 'readTVar' function reads the value stored in a
-- transactional variable.
readTVar          :: Typeable a => TVar a -> STM a
-- 'fromJust' cannot fail here: the 'TVar' constructor is not exported, so
-- a @TVar a@ only ever holds values stored by 'newTVar' / 'writeTVar' at
-- that same type @a@.
readTVar (TVar l) = ReadTVar l (STMReturn . fromJust . fromDynamic)

-- | The 'writeTVar' function overwrites the value stored in a
-- transactional variable.
writeTVar            :: Typeable a => TVar a -> a -> STM ()
writeTVar (TVar l) d = WriteTVar l (toDyn d) (STMReturn ())

-- | The 'retry' function abandons a transaction and retries at some
-- later time.
retry :: STM a
retry = Retry

-- | The 'check' function checks if its boolean argument holds. If
-- the boolean is true, it returns (); otherwise it calls 'retry'.
check       :: Bool -> STM ()
check True  = return ()
check False = retry

-- | The 'orElse' function takes two 'STM' actions @stm1@ and @stm2@ and
-- performs @stm1@. If @stm1@ calls 'retry' it performs @stm2@. If @stm1@
-- succeeds, on the other hand, @stm2@ is not executed — even if code
-- sequenced after the 'orElse' later calls 'retry'.
orElse     :: STM a -> STM a -> STM a
orElse p q = OrElse p q STMReturn

-- Interpret a transaction in the 'VM'. 'Nothing' means the transaction
-- retried. Heap updates are applied directly to the VM state, so callers
-- that must undo a retried attempt (the 'STMS' step and the 'OrElse' case
-- below) run this on a snapshot via 'runStateT'.
executeSTM :: STM a -> VM (Maybe a)
executeSTM (STMReturn x)     = return (Just x)
executeSTM (NewTVar d k)     = do
  loc <- alloc
  updateHeap loc d
  executeSTM (k loc)
executeSTM (ReadTVar l k)    = do
  contents <- lookupHeap l
  case contents of
    Just d  -> executeSTM (k d)
    -- A TVar location is always allocated by NewTVar before it can be
    -- read, so a missing entry is an interpreter bug.
    Nothing -> internalError "Reading an unallocated TVar"
executeSTM (WriteTVar l d k) = do
  updateHeap l d
  executeSTM k
executeSTM Retry             = return Nothing
executeSTM (OrElse p q k)    = do
  store <- get
  -- Run the left alternative on a snapshot so its writes can be discarded
  -- if it retries.
  case runStateT (executeSTM p) store of
    Done (Just x, store') -> do
      put store'
      executeSTM (k x)
    Done (Nothing, _)     -> do
      fallback <- executeSTM q
      maybe (return Nothing) (executeSTM . k) fallback
    _                     -> internalError "Unsafe usage of STM"

-- Abort with a message tagged with this module's name; used for states
-- the interpreter should never reach.
internalError :: String -> a
internalError = error . ("IOSpec.STM: " ++)