Hatena::Grouphaskell

mokeheheの日記

2009-08-03

電卓で任意の型の値を扱えるようにする

今までは計算の中でDoubleの値しか扱えなかったのを、任意の型を扱えるように変更する。

まずevalで返す値をDouble型から任意の型を表すValue型に変更する:

eval :: Node -> MyST Value

あとはコンパイラが出すエラーに従ってソースを修正していけばOK。Haskellはこの安心感がたまらない。また今まで比較演算子が1.0(真)や0.0(偽)を返していたところをバリアントのBool値を返すように変更する。

false   = DBool False
true    = DBool True

今までは演算ノードのevalはそのまま取り出したDouble値同士で演算すればよかったが、ちゃんと値の型を見て分岐させるように変更する。

ひとつハマったことは、比較演算子で取り出した型の値で演算を行うと型が固定されてしまって別の型に適用できなくなってしまった:

arith LT n1 n2 = evalOrd (<) n1 n2
evalOrd op n1 n2 = do
	v1 <- eval n1
	v2 <- eval n2
	case (v1, v2) of
		-- エラー:opが上の行でDouble型に適用されてしまっているため、下の行のBool型に適用できない
		(DDouble d1, DDouble d1) -> return $ DBool $ d1 `op` d2
		(DBool   b1, DBool   b1) -> return $ DBool $ b1 `op` b2
		_                        -> fail "illegal cmp"

(<)の型が(Ord a)=>a->a->Boolのまま扱われないのが謎だった。Value型をOrd型クラスにして、Value型のまま比較して回避した。Value型は関数も含んでいてderiving句が使えないのでしかたなく手書き。なぜOrdクラスのインスタンスを書くとき、(<)だけを定義すればあとはデフォルトの定義で何とかしてくれないのか?Eqクラスは(==)か(/=)のどちらかを定義すれば大丈夫なのに…。

今までは呼び出し先の関数は常に関数名で参照していたところを、先頭ノードの評価結果の関数に対してapplyするように変更:

eval (Funcall fnode params) = do
	fval <- eval fnode
	evaledParams <- evalNodeList params
	st <- get
	call fnode fval evaledParams

ネイティブ関数の呼び出しは、手抜きで引数はDouble値だと決め付けてしまっている。

true, false を定義。

Main.hs
import Prelude hiding (catch)
import Control.Exception (catch)
import Text.ParserCombinators.Parsec (runParser, eof)
import System.IO

import Parser (stmt)
import Intp (intp, IntpState(..), Value(..))

repl :: String -> (String -> Bool) -> (String -> st -> (String, st)) -> st -> IO st
repl prompt bQuit eval = loop
	where
		loop st = hSetBuffering stdout NoBuffering >> putStr prompt >> getLine >>= act st
		act st s
			| bQuit s = return st
			| otherwise = catch (exec st s) (\exc -> print exc >> loop st)
		exec st s = do
			let (res, st') = eval s st
			putStrLn res
			loop st'

calc :: IntpState -> IO IntpState
calc = repl "> " (== ":q") parse
	where
		parse line st     = either (\err -> (show err, st)) (evaluate st) $ runParser parser () "" line
		evaluate st node  = result $ intp node st
		result (res, st)  = (show res, st)
		parser = stmt >>= \e -> eof >> return e

-- 初期の状態
initialState = IntpState genv
	where
		genv = [
			("pi", DDouble pi),
			("true",  DBool True),
			("false", DBool False),
			("sin", Native 1 (apply1 sin)),
			("cos", Native 1 (apply1 cos)),
			("tan", Native 1 (apply1 tan)),
			("log", Native 1 (apply1 log)),
			("sqrt", Native 1 (apply1 sqrt))
			]
		apply1 f = \[x] -> f x

main = putStrLn "type ':q' to quit." >> calc initialState >> putStrLn "Bye"
Parser.hs
module Parser (
	Node(..),
	Op(..),
	stmt
) where

import Prelude hiding (EQ, LT, GT)
import Text.ParserCombinators.Parsec
import Text.ParserCombinators.Parsec.Expr
import qualified Text.ParserCombinators.Parsec.Token as P
import Text.ParserCombinators.Parsec.Language

-- 演算子
data Op = Add | Sub | Mul | Div | Pow | Negate
	| EQ | NE | LT | LE | GT | GE
	| LogiAnd | LogiOr
	deriving (Show)

-- ノード
data Node =
		Literal Double
	|	Arith Op Node Node
	|	Ident String
	|	Assign String Node
	|	If Node Node Node
	|	Funcall Node [Node]
	|	Defun String [String] Node
	deriving (Show)

type MyParserState = ()

-- パーサ型
type MyParser a = GenParser Char MyParserState a

lexer :: P.TokenParser MyParserState
lexer = P.makeTokenParser (haskellDef { reservedOpNames = ["*","/","+","-","**", "==", "/=", "<", "<=", ">", ">=", "&&", "||", "?", ":"] })

naturalOrFloat = P.naturalOrFloat lexer
parens         = P.parens lexer
reservedOp     = P.reservedOp lexer
identifier     = P.identifier lexer
lexeme         = P.lexeme lexer

---------------------------------------
-- Parser (Generate Abstract-Syntax-Tree)

stmt :: MyParser Node
stmt = try(defun) <|> expr

expr :: MyParser Node
expr = assignExpr

assignExpr :: MyParser Node
assignExpr = try(assign) <|> condExpr
	where
		assign = do
			var <- identifier
			lexeme $ char '='
			e <- expr
			return $ Assign var e

condExpr :: MyParser Node
condExpr = try(cond) <|> expr'
	where
		cond = do
			c <- expr'
			lexeme $ char '?'
			t <- expr
			lexeme $ char ':'
			e <- expr
			return $ If c t e

expr' :: MyParser Node
expr' = buildExpressionParser table factor <?> "expression"
	where
		table = [
			[unary "-" (Arith Negate (Literal 0)), unary "+" id],
			[op "**" (Arith Pow) AssocRight],
			[op "*" (Arith Mul) AssocLeft, op "/" (Arith Div) AssocLeft],
			[op "+" (Arith Add) AssocLeft, op "-" (Arith Sub) AssocLeft],
			[op "==" (Arith EQ) AssocNone, op "/=" (Arith NE) AssocNone, op "<" (Arith LT) AssocNone, op "<=" (Arith LE) AssocNone, op ">" (Arith GT) AssocLeft, op ">=" (Arith GE) AssocNone],
			[op "&&" (Arith LogiAnd) AssocLeft],
			[op "||" (Arith LogiOr) AssocLeft]
			]
		op s f assoc = Infix (do{ reservedOp s; return f } <?> "operator") assoc
		unary s f = Prefix (do{ reservedOp s; return f })

factor :: MyParser Node
factor = try(funcall) <|> primFactor

funcall :: MyParser Node
funcall = do
	fnode <- primFactor
	params <- many1 primFactor
	return $ Funcall fnode params

primFactor :: MyParser Node
primFactor = (parens expr) <|> floatLiteral <|> varref <?> "factor"

floatLiteral :: MyParser Node
floatLiteral = naturalOrFloat >>= return . either (Literal . fromInteger) Literal

varref :: MyParser Node
varref = do
	name <- identifier
	return $ Ident name

defun :: MyParser Node
defun = do
	name <- identifier
	args <- many1 identifier
	lexeme $ char '='
	e <- expr
	return $ Defun name args e
Intp.hs
module Intp (
	intp,
	IntpState(..),
	Value(..)
) where

import Prelude hiding (EQ, LT, GT, catch)
import Parser (Node(..), Op(..))
import Control.Monad.State

-- 値
data Value =
		DDouble Double                   -- 直値
	|	DBool Bool                       -- 真偽値
	|	Function [String] Node           -- 関数
	|	Native Int ([Double] -> Double)  -- ネイティブ関数

instance Show Value where
	show (DDouble d)    = show d
	show (DBool True)   = "true"
	show (DBool False)  = "false"
	show (Function _ _) = "<function>"
	show (Native _ _)   = "<native>"

instance Eq Value where
	(DDouble d1) == (DDouble d2)  = d1 == d2
	(DBool   b1) == (DBool   b2)  = b1 == b2
	_           == _              = False

instance Ord Value where
	(DDouble d1) <  (DDouble d2)  = d1 <  d2
	(DBool   b1) <  (DBool   b2)  = b1 <  b2
	(DDouble d1) <= (DDouble d2)  = d1 <= d2
	(DBool   b1) <= (DBool   b2)  = b1 <= b2
	(DDouble d1) >  (DDouble d2)  = d1 >  d2
	(DBool   b1) >  (DBool   b2)  = b1 >  b2
	(DDouble d1) >= (DDouble d2)  = d1 >= d2
	(DBool   b1) >= (DBool   b2)  = b1 >= b2

-- 環境
type Environment = [(String, Value)]

doAssign :: String -> Value -> Environment -> Environment
doAssign var val env = (var, val) : filter ((/= var) . fst) env

-- インタプリタの状態
data IntpState =
	IntpState {
		global :: Environment
		}
	deriving (Show)

-- インタープリト
intp :: Node -> IntpState -> (Value, IntpState)
intp node state = runState (eval node) state

---------------------------------------
-- Intp

false   = DBool False
true    = DBool True
isFalse = (== false)
isTrue  = (/= false)

type MyST a = State IntpState a

eval :: Node -> MyST Value
eval (Literal v)      = return $ DDouble v
eval (Arith op n1 n2) = arith op n1 n2
eval (Ident name) = do
	st <- get
	case (lookup name $ global st) of
		Nothing -> fail $ "undefined variable: " ++ name
		Just v  -> return v
eval (Assign name n) = do
	v <- eval n
	modify (\st -> st{ global = doAssign name v (global st) })
	return v
eval (If c t e) = do
	v <- eval c
	eval $ if isTrue v then t else e
eval (Funcall fnode params) = do
	fval <- eval fnode
	evaledParams <- evalNodeList params
	st <- get
	call fnode fval evaledParams
eval (Defun name args body) = do
	modify (\st -> st { global = doAssign name f (global st) })
	return f
	where f = Function args body

arith Add     n1 n2 = evalArith (+)  n1 n2
arith Sub     n1 n2 = evalArith (-)  n1 n2
arith Mul     n1 n2 = evalArith (*)  n1 n2
arith Div     n1 n2 = evalArith (/)  n1 n2
arith Pow     n1 n2 = evalArith (**) n1 n2
arith EQ      n1 n2 = evalEq    (==) n1 n2
arith NE      n1 n2 = evalEq    (/=) n1 n2
arith LT      n1 n2 = evalOrd   (<)  n1 n2
arith LE      n1 n2 = evalOrd   (<=) n1 n2
arith GT      n1 n2 = evalOrd   (>)  n1 n2
arith GE      n1 n2 = evalOrd   (>=) n1 n2
arith LogiAnd n1 n2 = evalShortcut isFalse n1 n2
arith LogiOr  n1 n2 = evalShortcut isTrue  n1 n2
arith Negate  n1 n2 = evalArith (-)  n1 n2

evalArith op n1 n2 = do
	v1 <- eval n1
	v2 <- eval n2
	case (v1, v2) of
		(DDouble d1, DDouble d2) -> return $ DDouble $ d1 `op` d2
		_                        -> fail "illegal operation"
evalEq op n1 n2 = do
	v1 <- eval n1
	v2 <- eval n2
	return $ DBool $ v1 `op` v2
evalOrd op n1 n2 = do
	v1 <- eval n1
	v2 <- eval n2
	case (v1, v2) of
		(DDouble _, DDouble _) -> return $ DBool $ v1 `op` v2
		(DBool   _, DBool   _) -> return $ DBool $ v1 `op` v2
		_                      -> fail "illegal operation"
evalShortcut f n1 n2 = do
	v1 <- eval n1
	if f v1 then return v1 else eval n2

evalNodeList :: [Node] -> MyST [Value]
evalNodeList []     = return []
evalNodeList (x:xs) = do
	v <- eval x
	vs <- evalNodeList xs
	return $ v:vs

call :: Node -> Value -> [Value] -> MyST Value
call fnode (Function args body) params
	| length params /= length args  = fail $ show fnode ++ ": illegal argnum, " ++ show (length params) ++ " for " ++ show (length args)
	| otherwise = do
		st <- get
		put $ st { global = override args params (global st) }
		res <- eval body
		put st
		return res
	where
		override args params env =
			foldl (\e (a,p) -> doAssign a p e) env $ zip args params
call fnode (Native argnum f) params
	| length params /= argnum = fail $ show fnode ++ ": illegal argnum, " ++ show (length params) ++ " for " ++ show argnum
	| otherwise = return $ DDouble $ f $ map (\(DDouble v) -> v) params
call fnode v _ = fail $ show fnode ++ " is not function"

GHC6.10でControl.Exceptionが更新されていた

GHC6.10 でControl.Exceptionのcatchが変更されていたらしい。GHC6.8.3を使ってたので気づかなかった…。古いものとの互換バージョンはControl.OldExceptionになってるらしい。うげー。

新しいControl.ExceptionのcatchでStateモナド中のfailで起こした例外をキャッチするには

import Control.Exception (catch, SomeException)

  ...
  catch action (\exc -> print (exc :: SomeException))

と、SomeExceptionで受けるらしい。SomeExceptionて名前はどうなの、と思うのだけど。

トラックバック - http://haskell.g.hatena.ne.jp/mokehehe/20090803