Commit 3af243c5 authored by Hans-Peter Deifel's avatar Hans-Peter Deifel 🐢
Browse files

wta: Change generation algorithm to accommodate sparse graphs

The current way of iterating trough all possible edges and deciding for each
edge if we take it or not has proven not to scale to large but sparse graphs.

Instead, we now generate the edges that we want directly.
parent 5c5d3534
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -290,6 +290,7 @@ executable random-wta
                     , Generator
                     , Output
                     , Probability
                     , IndexedTransition
  default-language:    Haskell2010
  default-extensions:  OverloadedStrings
                     , LambdaCase
@@ -304,3 +305,4 @@ executable random-wta
                     , mtl >= 2.2 && <2.3
                     , megaparsec >= 7 && <8
                     , scientific >= 0.3 && <0.4
                     , containers >= 0.6 && <0.7
+44 −3
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}

module Generator (genWTA, runGenerator, GeneratorConfig(..)) where
module Generator (genWTA, runGenerator, GeneratorConfig(..), ZeroFrequency(..)) where

import           Data.Vector                    ( Vector )
import qualified Data.Vector                   as V
@@ -12,18 +12,29 @@ import Data.Coerce
import           Data.Maybe
import           Data.Foldable
import           Control.Arrow                  ( (&&&) )
import qualified Data.IntMap.Strict            as M
import qualified Data.IntSet            as S
import           Data.Coerce

import           Types                   hiding ( spec )
import           Probability
import           IndexedTransition
import qualified IndexedTransition

data ZeroFrequency = Percentage Probability | OutDegree Int

data GeneratorConfig m = GeneratorConfig
   { spec :: WTASpec m
   , zeroFreq :: Probability
   , zeroPolicy :: ZeroFrequency
   , differentValues :: Maybe Int
   }

type Generator m = ReaderT (GeneratorConfig m) IO

zeroFreq :: GeneratorConfig m -> Probability
zeroFreq (GeneratorConfig { zeroPolicy = Percentage p }) = p
zeroFreq _ = error "zeroFreq: unexpected out degree" -- TODO Ugly as hell

runGenerator :: GeneratorConfig m -> Generator m a -> IO a
runGenerator config action = runReaderT action config

@@ -85,5 +96,35 @@ genTransitions = do
  n <- asks (numStates . spec)
  V.replicateM n genStateTransitions

-- TODO Implement (Random IndexedTransition)
uniqueTransitions :: Int -> IndexedTransition -> IO [IndexedTransition]
uniqueTransitions num (IndexedTransition.Index max) = helper S.empty num
  where
    helper m 0 = return $ coerce (S.toList m)
    helper m c = do
      x <- randomRIO (0, max-1)
      if x `S.member` m then helper m c else helper (S.insert x m) (c-1)


genTransitions' :: Int -> Generator m (Vector (Vector (Transition m)))
genTransitions' outDegree = do
  wtaSpec <- asks spec
  let n = numStates wtaSpec
      m = IndexedTransition.maxIndex wtaSpec
      desiredEdges = n * outDegree

  transitions <- lift $ map (IndexedTransition.fromIndex wtaSpec) <$> uniqueTransitions desiredEdges m
  weightedTransitions <- (traverse.traverse.traverse) (const genMonoidValue) transitions

  let byState = foldl' (\m (State s, t) -> M.insertWith (++) s [t] m) M.empty weightedTransitions

  let stateVec = V.generate n $ \i -> case M.lookup i byState of
        Nothing -> V.empty
        Just lst -> V.fromList lst

  return stateVec

genWTA :: Generator m (WTA m)
genWTA = WTA <$> asks spec <*> genStates <*> genTransitions
genWTA = asks zeroPolicy >>= \case
  OutDegree d -> WTA <$> asks spec <*> genStates <*> (genTransitions' d)
  Percentage _ -> WTA <$> asks spec <*> genStates <*> genTransitions
+89 −0
Original line number Diff line number Diff line
{-# LANGUAGE ScopedTypeVariables #-}

module IndexedTransition (IndexedTransition(..), maxIndex, fromIndex) where

import           Data.Vector (Vector)
import qualified Data.Vector as V
import Data.Maybe
import Data.Tuple

import Types

import Debug.Trace

newtype IndexedTransition = Index Int
  deriving (Show)

maxIndex :: WTASpec m -> IndexedTransition
maxIndex spec = 
  let n = numStates spec
      (t, _) = transitionsPerState spec
  in Index (n * t)

fromIndex :: WTASpec m -> IndexedTransition -> (State, Transition ())
fromIndex spec (Index i) =
  let n = numStates spec
      (t, symbolSums) = transitionsPerState spec

      (state, stateTransition) = traceShow i $ i `divMod` t

      -- fromJust is justified (ho ho) here, since `stateTransition` should
      -- never be greater than the total number of possible transitions for this
      -- state (which is the last value in symbolSums).
      arity = fromJust (V.findIndex (> stateTransition) symbolSums) - 1

      arityTransition = stateTransition - (symbolSums V.! arity)

      symbolBounds :: Vector Int = V.cons (numSymbols spec V.! arity) (V.replicate arity n)
      symbolDigits = decodeFromInt symbolBounds arityTransition
      symbol = V.head symbolDigits 
      successors = V.tail symbolDigits 


      trans = Transition
        { weight = ()
        , summand = aritySummand spec arity
        , symbol = symbol
        , successors = V.map State successors
        }

  in (State state, trans)

index :: WTASpec m1 -> Int -> Transition m2 -> IndexedTransition
index spec state trans =
  let (t, symbolSums) = transitionsPerState spec

      arity :: Int = summandArity spec (summand trans)
      symbolBounds :: Vector Int = V.cons (numSymbols spec V.! arity) (V.replicate arity (numStates spec))
      arityIdx :: Int = encodeAsInt symbolBounds (V.cons (symbol trans) (V.map fromState $ successors trans)) 
      stateLocal :: Int = symbolSums V.! arity + arityIdx

  in Index $ state * t + stateLocal

-- Helpers

summandArity :: WTASpec m -> Int -> Int
summandArity spec summand = V.findIndices (/= 0) (numSymbols spec) V.! summand

aritySummand :: WTASpec m -> Int -> Int
aritySummand spec arity =
  let arities = numSymbols spec
  in V.length (V.filter (/= 0) (V.take arity arities))

transitionsPerState :: WTASpec m -> (Int, Vector Int)
transitionsPerState spec = 
  let n = numStates spec
      tPerSymbol = (V.imap (\i syms -> syms * n ^ i) (numSymbols spec))
      runningTotal = V.scanl' (+) 0 tPerSymbol
  in (V.last runningTotal, runningTotal)

encodeAsInt :: Vector Int -> Vector Int -> Int
encodeAsInt maxBounds digits =
  let factors = V.prescanr' (*) 1 maxBounds
  in sum (V.zipWith (*) factors digits)

decodeFromInt :: Vector Int -> Int -> Vector Int
decodeFromInt maxBounds encoded =
  V.map fst $ V.postscanr' doDigit (0, encoded) maxBounds

  where doDigit bound (_, current) = (swap $ current `divMod` bound)
+3 −4
Original line number Diff line number Diff line
@@ -27,7 +27,6 @@ import Probability

data SomeMonoid = forall m. SomeMonoid (MonoidType m)

data ZeroFrequency = Percentage Probability | OutDegree Int

data Opts = Opts
  { optMonoid :: SomeMonoid
@@ -158,10 +157,10 @@ main = do

  withSpec opts $ \spec -> do
    randGen <- getStdGen
    let zeroFreq = computeProbability spec (optZeroFrequency opts)
    hPutStrLn stderr $ "p hacking: " ++ show zeroFreq
    -- let zeroFreq = computeProbability spec (optZeroFrequency opts)
    -- hPutStrLn stderr $ "p hacking: " ++ show zeroFreq
    wta <- runGenerator
      (GeneratorConfig spec zeroFreq (optDifferentValues opts))
      (GeneratorConfig spec (optZeroFrequency opts) (optDifferentValues opts))
      genWTA
    putStrLn $ "# Random state for this automaton: '" <> show randGen <> "'"
    T.putStr (Build.toLazyText (buildWTA wta))
+4 −1
Original line number Diff line number Diff line
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}

module Types
  ( MonoidType(..)
@@ -36,7 +39,7 @@ data Transition m = Transition
  , summand :: Int
  , symbol :: Int
  , successors :: Vector State
  } deriving (Show)
  } deriving (Show, Functor, Foldable, Traversable)

data WTA m = WTA
  { spec :: WTASpec m