Commit dcb87ef9 authored by Hans-Peter Deifel's avatar Hans-Peter Deifel 🐢
Browse files

wta: Fix whitespace and refactor generator

parent 269a5561
Loading
Loading
Loading
Loading
+55 −28
Original line number Diff line number Diff line
@@ -2,7 +2,13 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}

module Generator (genWTA, runGenerator, GeneratorConfig(..), EdgeConfig(..)) where
module Generator
  ( genWTA
  , runGenerator
  , GeneratorConfig(..)
  , EdgeConfig(..)
  )
where

import           Data.Vector                    ( Vector )
import qualified Data.Vector                   as V
@@ -93,27 +99,35 @@ genStateTransitions = do
  arities <- asks (numSymbols . spec)
  fold <$> traverse genForArity (V.findIndices (/= 0) arities)

genTransitions :: Generator m (Vector (Vector (Transition m)))
genTransitions = do
genTransitionsZeroFreq :: Generator m (Vector (Vector (Transition m)))
genTransitionsZeroFreq = do
  n <- asks (numStates . spec)
  V.replicateM n genStateTransitions

-- TODO Implement (Random IndexedTransition)
uniqueTransitions :: Int -> IndexedTransition -> IO [IndexedTransition]
uniqueTransitions num idxMax@(IndexedTransition.Index max)
  | fromIntegral num < fromIntegral max * (7%10) = uniqueTransitionsByGeneration num idxMax
  | fromIntegral num < fromIntegral max * (7 % 10) = uniqueTransitionsByGeneration
    num
    idxMax
  | otherwise = uniqueTransitionsByElimination num idxMax

uniqueTransitionsByGeneration :: Int -> IndexedTransition -> IO [IndexedTransition]
uniqueTransitionsByGeneration num (IndexedTransition.Index max) = helper S.empty num
uniqueTransitionsByGeneration
  :: Int -> IndexedTransition -> IO [IndexedTransition]
uniqueTransitionsByGeneration num (IndexedTransition.Index max) = helper
  S.empty
  num
  where
    helper m 0 = return $ coerce (S.toList m)
    helper m c = do
      x <- randomRIO (0, max - 1)
      if x `S.member` m then helper m c else helper (S.insert x m) (c - 1)

uniqueTransitionsByElimination :: Int -> IndexedTransition -> IO [IndexedTransition]
uniqueTransitionsByElimination num (IndexedTransition.Index max) = helper whole num
uniqueTransitionsByElimination
  :: Int -> IndexedTransition -> IO [IndexedTransition]
uniqueTransitionsByElimination num (IndexedTransition.Index max) = helper
  whole
  num
  where
    helper free 0 = return $ coerce (S.toList (S.difference whole free))
    helper free c = do
@@ -123,23 +137,35 @@ uniqueTransitionsByElimination num (IndexedTransition.Index max) = helper whole

    whole = S.fromList [0 .. max - 1]

genTransitions' :: Int -> Generator m (Vector (Vector (Transition m)))
genTransitions' numTransitions = do
genTransitionsNumTrans :: Int -> Generator m (Vector (Vector (Transition m)))
genTransitionsNumTrans numTransitions = do
  wtaSpec <- asks spec
  let n    = numStates wtaSpec
      maxT = IndexedTransition.maxIndex wtaSpec

  numTransitions' <- if (fromIntegral numTransitions > maxT) then do
  numTransitions' <- if (fromIntegral numTransitions > maxT)
    then do
      let cap = IndexedTransition.fromIndexd maxT
      lift $ hPutStrLn stderr ("warning: More transitions than possible requested. Capping at " <> show cap)
      lift $ hPutStrLn
        stderr
        (  "warning: More transitions than possible requested. Capping at "
        <> show cap
        )
      return (fromIntegral cap)
    else do
      return numTransitions

  transitions <- lift $ map (IndexedTransition.fromIndex wtaSpec) <$> uniqueTransitions numTransitions' maxT
  weightedTransitions <- (traverse.traverse.traverse) (const genMonoidValue) transitions
  transitions <-
    lift
    $   map (IndexedTransition.fromIndex wtaSpec)
    <$> uniqueTransitions numTransitions' maxT
  weightedTransitions <- (traverse . traverse . traverse)
    (const genMonoidValue)
    transitions

  let byState = foldl' (\m (State s, t) -> M.insertWith (++) s [t] m) M.empty weightedTransitions
  let byState = foldl' (\m (State s, t) -> M.insertWith (++) s [t] m)
                       M.empty
                       weightedTransitions

  let stateVec = V.generate n $ \i -> case M.lookup i byState of
        Nothing  -> V.empty
@@ -149,5 +175,6 @@ genTransitions' numTransitions = do

genWTA :: Generator m (WTA m)
genWTA = asks zeroPolicy >>= \case
  NumTransitions d -> WTA <$> asks spec <*> genStates <*> (genTransitions' d)
  ZeroFrequency _ -> WTA <$> asks spec <*> genStates <*> genTransitions
  NumTransitions d ->
    WTA <$> asks spec <*> genStates <*> (genTransitionsNumTrans d)
  ZeroFrequency _ -> WTA <$> asks spec <*> genStates <*> genTransitionsZeroFreq