diff --git a/haskell/scratch/Monadic.hs b/haskell/scratch/Monadic.hs index 92cc3f0..78ab806 100644 --- a/haskell/scratch/Monadic.hs +++ b/haskell/scratch/Monadic.hs @@ -6,21 +6,34 @@ data MyMonad a = DivByZeroError String | NotDivisible Int | Divisible a deriving (Show) -- This is where we define fmap (i.e. map but for monads) +-- Essentially, this is where we define what happens when we apply a function to a monad +-- i.e.: +-- fmap (+1) (Divisible 3) => Divisible 4 +-- fmap (+1) (NotDivisible 3) => NotDivisible 3 +-- fmap (+1) (DivByZeroError "error") => DivByZeroError "error" instance Functor MyMonad where - fmap f (DivByZeroError s) = DivByZeroError s - fmap f (Divisible a) = Divisible (f a) - fmap f (NotDivisible a) = NotDivisible a + fmap f (DivByZeroError s) = DivByZeroError s -- Note the error + fmap f (Divisible a) = Divisible (f a) -- Apply the function f to the value a + fmap f (NotDivisible a) = NotDivisible a -- Not divisible, propagate the value forward -- This is where we define <*> (i.e. apply but for monads) essentially applying the -- "wrapped" function to the "wrapped" value +-- i.e.: +-- (Divisible (+1)) <*> (Divisible 3) => Divisible 4 +-- (Divisible (+1)) <*> (NotDivisible 3) => NotDivisible 3 +-- (Divisible (+1)) <*> (DivByZeroError "error") => DivByZeroError "error" instance Applicative MyMonad where - (DivByZeroError s) <*> _ = DivByZeroError s - (Divisible a) <*> b = fmap a b - (NotDivisible a) <*> b = NotDivisible a - pure = Divisible + (DivByZeroError s) <*> _ = DivByZeroError s -- Note the error + (Divisible a) <*> b = fmap a b -- Apply the function a to the value b + (NotDivisible a) <*> b = NotDivisible a -- Not divisible, propagate the value forward + pure = Divisible -- Wrap the value in the Divisible constructor -- This is where we define >>=, the bind operator, which chains monadic operations --- i.e. in this case, we chain the operations of dividing and checking if the number is divisible +-- i.e. in this case, we chain: +-- tryDivide 2 420 >>= tryDivide 3 >>= tryDivide 5 >>= tryDivide 9 >>= tryDivide 0 +-- If we encounter a DivByZeroError, we propagate it forward (i.e. we don't do anything) +-- If we encounter a NotDivisible, we propagate it forward +-- If we encounter a Divisible, we apply the function f to the value x instance Monad MyMonad where DivByZeroError msg >>= _ = DivByZeroError msg -- if we encounter a DivByZeroError, we propagate it forward NotDivisible n >>= _ = NotDivisible n -- if we encounter a NotDivisible, we propagate it forward @@ -34,11 +47,11 @@ tryDivide d n | n `mod` d == 0 = Divisible (n `div` d) | otherwise = NotDivisible n --- Process the divisors (map out the results of tryDivide) +-- Sequentially divide the number n by the divisors in the list sequentialDividing :: [Int] -> Int -> MyMonad Int sequentialDividing [] n = return n sequentialDividing (d:ds) n = do - newN <- tryDivide d n + newN <- tryDivide d n -- Note how we're not doing any pattern matching here, we're just applying the function, and the monad takes care of the rest sequentialDividing ds newN main :: IO ()