Skip to content

Commit

Permalink
tinker: Simplex
Browse files Browse the repository at this point in the history
  • Loading branch information
gruhn committed Mar 13, 2023
1 parent 590ba6b commit 9c92834
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 10 deletions.
2 changes: 1 addition & 1 deletion SMT.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ library
, Theory.UninterpretedFunctions.Lazy
, Theory.UninterpretedFunctions.Eager
, Theory.LinearArithmatic.FourierMotzkin
, Theory.LinearArithmatic.Simplex
-- , Theory.LinearArithmatic.Simplex
, Theory.LinearArithmatic.BranchAndBound
, Theory.NonLinearRealArithmatic.Expr
, Theory.NonLinearRealArithmatic.Interval
Expand Down
100 changes: 91 additions & 9 deletions src/Theory/LinearArithmatic/Simplex.hs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{-# LANGUAGE TupleSections #-}
module Theory.LinearArithmatic.Simplex () where

import qualified Data.IntMap as M
import qualified Data.IntMap.Merge.Lazy as MM
import qualified Data.IntSet as S
import qualified Control.Monad.State as State
import Control.Monad.State (State)
import Control.Monad (unless, guard)

type Var = Int
Expand All @@ -12,20 +12,102 @@ type Var = Int
type Assignment a = M.IntMap a

-- | Map from variable IDs to coefficients
type Constraint a = M.IntMap a
type LinearTerm a = M.IntMap a

data BoundType = UpperBound | LowerBound

data Tableau a = Tableau
{ getNonBasis :: M.IntMap (Constraint a)
{ getNonBasis :: M.IntMap (LinearTerm a)
, getBounds :: M.IntMap (BoundType, a)
, getAssignment :: Assignment a
}

data BoundViolation = MustIncrease | MustDecrease

type Equation a = (Var, LinearTerm a)

{-|
Solve an equation for a given variable. For example solving
y = 3x + 10
for x, yields
x = 1/3 y - 10/3
-}
solveFor :: Equation a -> Var -> Maybe (Equation a)
solveFor (y, right_hand_side) x = do
coeff_of_x <- M.lookup x right_hand_side
guard (coeff_of_x /= 0)
return $ (x,)
-- divide by coefficient: x = -1/3y - 10/3
$ M.map (/ (-coeff_of_x))
-- exchange x and y: -3x = -y + 10
$ M.insert y (-1)
$ M.delete x right_hand_side

{-|
Given two equations, such as
e1: x = w + 2z
e2: y = 3x + 4z
rewrite x in e2 using e1:
y = 3(w + 2z) + 4z
= 3w + 10z
-}
rewrite :: Equation a -> Equation a -> Equation a
rewrite (x, term_x) (y, term_y) =
let
coeff_of_x = M.findWithDefault 0 x term_y
term_x' = M.map (* coeff_of_x) term_x
term_y' = M.unionWith (+) (M.delete x term_y) term_x'
in
(y, term_y')

eval :: Assignment a -> LinearTerm a -> a
eval assignment term = sum . MM.zipWithMatched (const (*)) assignment term

{-|
TODO: make sure:
- only slack variables are pivoted
- non-basic variables must violate bound
- basic variable is suitable for pivoting
-}
pivot' :: Var -> Var -> Tableau a -> Tableau a
pivot' basic_var non_basic_var (Tableau non_basis bounds assignment) =
let
from_just msg (Just a) = a
from_just msg Nothing = error msg

term = from_just "Given variable is not actually in the non-basis ==> invalid pivot pair"
$ M.lookup non_basic_var non_basis

equation = from_just "Can't solve for basic variable ==> invalid pivot pair"
$ solveFor (non_basic_var, term) basic_var

non_basis' = M.fromList $ rewrite equation <$> M.toList non_basis

old_value_basic_var = assignment M.! basic_var
new_value_basic_var = from_just "Basic variable doesn't have a bound so it's not actually violated"
$ snd <$> M.lookup basic_var bounds

basic_var_coeff = _

old_value_non_basic_var = non_basis M.! non_basic_bar
new_value_non_basic_var = old_value_non_basic_var + (old_value_basic_bar - new_value_basic_var) / basic_var_coeff

assignment' = M.union (eval <$> non_basis')
$ M.insert non_basic_var new_value_non_basic_var
$ M.insert basic_var new_value_basic_var assignment
in
Tableau non_basis' bounds assignment'

pivot :: forall a. (Num a, Ord a) => Tableau a -> Tableau a
pivot (Tableau non_basis bounds assignment) =
pivot (Tableau non_basis bounds assignment) =
let
basis = M.keysSet assignment S.\\ M.keysSet non_basis

Expand All @@ -48,11 +130,11 @@ pivot (Tableau non_basis bounds assignment) =
pivot_candidates :: [ (Var, Var) ]
pivot_candidates = do
(basic_var, violation) <- violated_basic_vars
(non_basic_var, constraint) <- M.toAscList non_basis
(non_basic_var, term) <- M.toAscList non_basis

let basic_var_coeff = M.findWithDefault 0 basic_var constraint
let basic_var_coeff = M.findWithDefault 0 basic_var term
bound_type = fst <$> M.lookup non_basic_var bounds
can_pivot =
can_pivot =
case (bound_type, signum basic_var_coeff, violation) of
-- If the coefficient of the basic variable is 0, then the value of the variable
-- is not affected by pivoting, so it can't resolve the bound violation.
Expand Down Expand Up @@ -87,7 +169,7 @@ pivot (Tableau non_basis bounds assignment) =
-- variable value larger.
(Just LowerBound, -1, MustDecrease) -> True

-- In all other cases the bound of the non-basic variable would be vioalted by
-- In all other cases the bound of the non-basic variable would be violated by
-- pivoting.
all_other_cases -> False

Expand Down

0 comments on commit 9c92834

Please sign in to comment.