Haskellで麻雀の待ち判定プログラムを書いた

makeplex salon:あなたのスキルで飯は食えるか? 史上最大のコーディングスキル判定 (1/2) - ITmedia エンタープライズ
ちょっと流行に乗り遅れた感があるけど書いてみた。どうやら自分の実力ではHaskellで飯を食うのは無理らしいw

ソースコード

前半はユーティリティ関数やデータ構築子の定義などで,本質的なのは後半のsearchPair以降です。

module Main where

import Data.List (delete, nub, elemIndices, foldl', sort)
import Data.Char (digitToInt)
import Control.Monad
import System.IO (hFlush, stdout)

count :: (Eq a) => a -> [a] -> Int
count = countBy . (==)

countBy :: (a -> Bool) -> [a] -> Int
countBy f = foldl' (\s x -> if f x then s+1 else s) 0

data Pair = Shuntsu Int | Kotsu Int | Jantoh Int | Wait Int Pair
          deriving (Eq, Ord)

instance Show Pair where
  show p @ (Wait i _) = "[" ++ concatMap show (toList p) ++ "|" ++ show i ++ "]"
  show p              = "(" ++ concatMap show (toList p) ++ ")"

toList :: Pair -> [Int]
toList (Shuntsu i) = [i..i+2]
toList (Kotsu   i) = [i, i, i]
toList (Jantoh  i) = [i, i]
toList (Wait i p)  = delete i $ toList p

removePair :: Pair -> [Int] -> [Int]
removePair p xs = foldl' (flip delete) xs (toList p)

isShuntsu, isKotsu, isJantoh, isWait  :: Pair -> Bool
isShuntsu (Shuntsu _) = True
isShuntsu _           = False
isKotsu   (Kotsu   _) = True
isKotsu   _           = False
isJantoh  (Jantoh  _) = True
isJantoh  _           = False
isWait    (Wait _  _) = True
isWait    _           = False



type ParseResult a = [(a, [Int])]
newtype Parser a = Parser ([Int] -> ParseResult a)

parse :: Parser a -> [Int] -> ParseResult a
parse (Parser p) = p

instance Monad Parser where
  return a = Parser $ \xs -> [(a, xs)]
  p >>= f  = Parser $ \cs -> concat [parse (f a) cs' | (a, cs') <- parse p cs]

instance MonadPlus Parser where
  mzero     = Parser $ \_  -> []
  mplus p q = Parser $ \cs -> parse p cs ++ parse q cs

eof :: Parser ()
eof = Parser eof'
  where eof' [] = [((), [])]
        eof' _  = []

manyN :: Int -> Parser a -> Parser [a]
manyN 1 p = p >>= \x -> return [x]
manyN n p = p >>= \x -> manyN (n-1) p >>= \xs -> return (x:xs)



searchPair :: (Int -> [Int] -> Bool) -> (Int -> Pair) -> Parser Pair
searchPair filt cntr = Parser (\is -> search is (nub is))
  where search xs (x:xs') | filt x xs = [(p, removePair p xs)]
                          | otherwise = search xs xs'
          where p = cntr x
        search _ [] = []

shuntsu, kotsu, jantoh, waitShuntsu, waitKotsu, waitJantoh, wait :: Parser Pair
shuntsu = searchPair (\x xs -> elem (x+1) xs && elem (x+2) xs) Shuntsu
kotsu   = searchPair (\x xs -> length (elemIndices x xs) >= 3) Kotsu
jantoh  = searchPair (\x xs -> length (elemIndices x xs) >= 2) Jantoh
waitShuntsu =
  searchPair (\x xs -> elem (x+1) xs && x <= 7) (\x -> Wait (x+2) (Shuntsu x))
  `mplus`
  searchPair (\x xs -> elem (x+1) xs && x >= 2) (\x -> Wait (x-1) (Shuntsu (x-1)))
  `mplus`
  searchPair (\x xs -> elem (x+2) xs) (\x -> Wait (x+1) (Shuntsu x))
waitKotsu =
  searchPair (\x xs -> length (elemIndices x xs) >= 2) (\x -> Wait x (Kotsu x))
waitJantoh =
  searchPair (\x xs -> length (elemIndices x xs) >= 1) (\x -> Wait x (Jantoh x))
wait = waitShuntsu `mplus` waitKotsu `mplus` waitJantoh

pattern :: Parser [Pair]
pattern = normal `mplus` chitoi >>= \xs -> eof >> return xs
  where normal  = pairs >>= jFilt >>= wFilt
        chitoi  = manyN 7 (jantoh `mplus` waitJantoh) >>= wFilt
        pairs   = manyN 5 $ shuntsu `mplus` kotsu `mplus` jantoh `mplus` wait
        jFilt ps = if countBy (isJantoh) ps > 1 then mzero else return ps
        wFilt ps = if countBy (isWait)   ps > 1 then mzero else return ps

parsePattern :: [Int] -> ParseResult [Pair]
parsePattern = nub . map (\ (x, y) -> (sort x, y)) . parse pattern



parseText :: String -> [Int]
parseText = map digitToInt

showResult :: ParseResult [Pair] -> String
showResult = unlines . map (concatMap show . fst)

main :: IO ()
main = do
  ls <- getContents >>= return . lines
  mapM_ (putStr . showResult . parsePattern . parseText) ls
  hFlush stdout

実行結果

$ echo "1112224588899" | ./mahjang
(111)(222)(888)(99)[45|6]
(111)(222)(888)(99)[45|3]

$ echo "1122335556799" | ./mahjang
(123)(123)(567)(55)[99|9]
(123)(123)(567)(99)[55|5]
(123)(123)(555)(99)[67|8]
(123)(123)(555)(99)[67|5]

$ echo "1112223335559" | ./mahjang
(123)(123)(123)(555)[9|9]
(111)(222)(333)(555)[9|9]

$ echo "1223344888999" | ./mahjang
(123)(234)(888)(999)[4|4]
(234)(234)(888)(999)[1|1]
(123)(888)(999)(44)[23|4]
(123)(888)(999)(44)[23|1]

$ echo "1112345678999" | ./mahjang
(234)(567)(111)(999)[8|8]
(234)(678)(111)(999)[5|5]
(345)(678)(111)(999)[2|2]
(123)(456)(789)(11)[99|9]
(123)(456)(789)(99)[11|1]
(123)(456)(999)(11)[78|9]
(123)(456)(999)(11)[78|6]
(123)(678)(999)(11)[45|6]
(123)(678)(999)(11)[45|3]
(234)(567)(111)(99)[89|7]
(234)(789)(111)(99)[56|7]
(234)(789)(111)(99)[56|4]
(456)(789)(111)(99)[23|4]
(456)(789)(111)(99)[23|1]
(345)(678)(999)(11)[12|3]

とりあえずちゃんと動作しているようです。これで清一色を上がるときも安心!
七対子にも簡単に対応できました。

$ echo "1133557799223" | ./mahjang
(11)(22)(33)(55)(77)(99)[3|3]

実装方法としては,きちんとデータ構造定義して,パーサコンビネータを使って書いてみました。重複を無視して候補を全て列挙した後,重複を取り除くというやり方をしているので効率は悪いです。もうちょっと工夫の仕様はありそう。あと出力のフォーマットが設問と異なりますがそこはまあ気にしない。

追記

新たな関数combinationを追加して,ちょっとだけ効率よく探索できるようにした…つもり。

replaceAt :: Int -> a -> [a] -> [a]
replaceAt n x = replaceAtBy n (const x)

replaceAtBy :: Int -> (a -> a) -> [a] -> [a]
replaceAtBy n f xs = hd ++ f (head tl) : tail tl
  where (hd, tl) = splitAt n xs

combination :: [(Int, Parser a)] -> Parser [a]
combination [] = return []
combination ps = snd (foldl' f (0, mzero) ps) >>= \(i, x) ->
  combination (dec i ps) >>= \xs ->
  return (x:xs)
    where
      f :: (Int, Parser (Int, a)) -> (Int, Parser a) -> (Int, Parser (Int, a))
      f (i, ps') (_, p) = (i+1, ps' `mplus` (p >>= \x -> return (i, x)))
      dec i ps' = filter ((>0) . fst) $ replaceAtBy i (\(c, p) -> (c-1, p)) ps'

pattern :: Parser [Pair]
pattern = normal `mplus` chitoi >>= \xs -> eof >> return xs
  where normal = combination [(4, shuntsu `mplus` kotsu), (1, waitJantoh)]
                 `mplus`
                 combination [(3, shuntsu `mplus` kotsu), (1, jantoh),
                              (1, waitShuntsu `mplus` waitKotsu)]
        chitoi = combination [(6, jantoh), (1, waitJantoh)]

速度は割と速くなった気がします。探索時に"(123)(111)"と"(111)(123)"を別の物とみなして列挙しているので,まだまだ改良の余地はありそう。だけどとりあえずここで打ち止め。

さらに追記

mplusの重複を省いてくれる版な演算子 <+> を定義して使うことにした。[(123)(111)(11), (111)(123)(11), (111)(11)(123)] のように等価な要素をその都度その都度1まとめにしてくれます。ただ比較のコストが大きそうなので逆に遅くなってるかも?

(<+>) :: (Eq a) => Parser a -> Parser a -> Parser a
p <+> q = Parser $ \cs -> nub (parse p cs ++ parse q cs)

isSame :: (Eq a) => [a] -> [a] -> Bool
isSame xs ys = length xs == length ys && (null $ foldl' (flip delete) xs ys)

data Comb a = CCons a (Comb a) | CNil
instance (Eq a) => Eq (Comb a) where
  xs == ys = isSame (combToList xs) (combToList ys)

combToList :: Comb a -> [a]
combToList (CNil)       = []
combToList (CCons x xs) = x : combToList xs

combination :: (Eq a) => [(Int, Parser a)] -> Parser (Comb a)
combination [] = return CNil
combination ps = snd (foldl' f (0, mzero) ps) >>= \(i, x) ->
  combination (dec i ps) >>= \xs ->
  return (CCons x xs)
    where
      f (i, ps') (_, p) = (i+1, ps' <+> (p >>= \x -> return (i, x)))
      dec i ps' = filter ((>0) . fst) $ replaceAtBy i (\(c, p) -> (c-1, p)) ps'

parsePattern :: [Int] -> ParseResult [Pair]
parsePattern = map (\ (x, y) -> (sort x, y)) . parse pattern

わざわざCombなんてのを定義しなくても標準ライブラリに集合を表すコンテナとか用意されてる気がするな…。