Я нахожусь в середине переноса оригинальной реализации Дэвидом Блей «Скрытого распределения Дирихле» в Haskell, и я пытаюсь решить, следует ли оставить часть низкоуровневого материала в C. Следующая функция – один пример – это приближение второй производной от lgamma
:
double trigamma(double x) { double p; int i; x=x+6; p=1/(x*x); p=(((((0.075757575757576*p-0.033333333333333)*p+0.0238095238095238) *p-0.033333333333333)*p+0.166666666666667)*p+1)/x+0.5*p; for (i=0; i<6 ;i++) { x=x-1; p=1/(x*x)+p; } return(p); }
Я перевел это в более или менее идиоматический Haskell следующим образом:
trigamma :: Double -> Double trigamma x = snd $ last $ take 7 $ iterate next (x' - 1, p') where x' = x + 6 p = 1 / x' ^ 2 p' = p / 2 + c / x' c = foldr1 (\ab -> (a + b * p)) [1, 1/6, -1/30, 1/42, -1/30, 5/66] next (x, p) = (x - 1, 1 / x ^ 2 + p)
Проблема в том, что когда я запускаю оба параметра Criterion , моя версия Haskell в шесть или семь раз медленнее (я компилирую с -O2
в GHC 6.12.1). Некоторые подобные функции еще хуже.
Я практически ничего не знаю о производительности Haskell, и я не очень заинтересован в том, чтобы выкапывать Core или что-то в этом роде, так как я всегда могу просто назвать несколько математических функций C через FFI.
Но мне любопытно, есть ли плохие плоды, которые я пропускаю – какое-то расширение или библиотека или аннотация, которые я мог бы использовать, чтобы ускорить эту цифровую штуку, не делая ее слишком уродливой.
ОБНОВЛЕНИЕ: Вот два лучших решения, благодаря Дон Стюарту и Иццу . Я немного изменил ответ Ицца, чтобы использовать Data.Vector
.
invSq x = 1 / (x * x) computeP x = (((((5/66*p-1/30)*p+1/42)*p-1/30)*p+1/6)*p+1)/x+0.5*p where p = invSq x trigamma_d :: Double -> Double trigamma_d x = go 0 (x + 5) $ computeP $ x + 6 where go :: Int -> Double -> Double -> Double go !i !x !p | i >= 6 = p | otherwise = go (i+1) (x-1) (1 / (x*x) + p) trigamma_y :: Double -> Double trigamma_y x = V.foldl' (+) (computeP $ x + 6) $ V.map invSq $ V.enumFromN x 6
Производительность двух, кажется, почти точно такая же: выигрыш одного или другого на процентный пункт или два в зависимости от флагов компилятора.
Как сказал Camcann в Reddit , мораль этой истории – «Для достижения наилучших результатов используйте Don Stewart в качестве генератора кода GHC». Запрет на это решение, самая безопасная ставка, по-видимому, заключается в простом преобразовании структур управления C непосредственно в Haskell, хотя слияние фьюжн может придать аналогичную производительность в более идиоматическом стиле.
Вероятно, в конечном итоге я использую метод Data.Vector
в своем коде.
Используйте те же структуры управления и данных, что дает:
{-# LANGUAGE BangPatterns #-} {-# OPTIONS_GHC -fvia-C -optc-O3 -fexcess-precision -optc-march=native #-} {-# INLINE trigamma #-} trigamma :: Double -> Double trigamma x = go 0 (x' - 1) p' where x' = x + 6 p = 1 / (x' * x') p' =(((((0.075757575757576*p-0.033333333333333)*p+0.0238095238095238) *p-0.033333333333333)*p+0.166666666666667)*p+1)/x'+0.5*p go :: Int -> Double -> Double -> Double go !i !x !p | i >= 6 = p | otherwise = go (i+1) (x-1) (1 / (x*x) + p)
У меня нет вашего testuite, но это дает следующий asm:
A_zdwgo_info: cmpq $5, %r14 jg .L3 movsd .LC0(%rip), %xmm7 movapd %xmm5, %xmm8 movapd %xmm7, %xmm9 mulsd %xmm5, %xmm8 leaq 1(%r14), %r14 divsd %xmm8, %xmm9 subsd %xmm7, %xmm5 addsd %xmm9, %xmm6 jmp A_zdwgo_info
Что выглядит нормально. Это тот код, который -fllvm
бэкэнд -fllvm
.
GCC разворачивает цикл, и единственный способ сделать это – либо с помощью Template Haskell, либо с ручным разворачиванием. Вы можете подумать, что (макрос TH), если вы делаете много этого.
Фактически, бэкэнд LLVM GHC разворачивает цикл 🙂
Наконец, если вам действительно нравится оригинальная версия Haskell, напишите ее с помощью комбинаторов streamов, и GHC преобразует их обратно в циклы. (Упражнение для читателя).
Перед оптимизацией я бы не сказал, что ваш оригинальный перевод – самый идиоматический способ выразить в Haskell то, что делает код C.
Как бы процесс оптимизации продолжался, если бы мы начали со следующего:
trigamma :: Double -> Double trigamma x = foldl' (+) p' . map invSq . take 6 . iterate (+ 1) $ x where invSq y = 1 / (y * y) x' = x + 6 p = invSq x' p' =(((((0.075757575757576*p-0.033333333333333)*p+0.0238095238095238) *p-0.033333333333333)*p+0.166666666666667)*p+1)/x'+0.5*p