Chombo + EB  3.0
RootSolver.H
Go to the documentation of this file.
1 #ifdef CH_LANG_CC
2 /*
3  * _______ __
4  * / ___/ / ___ __ _ / / ___
5  * / /__/ _ \/ _ \/ V \/ _ \/ _ \
6  * \___/_//_/\___/_/_/_/_.__/\___/
8  */
9 #endif
10
11 #ifndef _ROOTSOLVER_H_
12 #define _ROOTSOLVER_H_
13
14
15 /******************************************************************************/
16 /**
17  * \file
18  *
19  * \brief Root solvers
20  *
21  *//*+*************************************************************************/
22
#include <cmath>
#include <algorithm>
#include <iostream>

#include "CH_assert.H"
27
29
30 namespace RootSolver
31 {
32
33 template <typename T> struct RootTr
34 {
35 };
36
37 // Default epsilon and tolerance for floats
38 template <> struct RootTr<float>
39 {
40  enum
41  {
42  maxIter = 100
43  };
44
45  static float eps()
46  {
47  return 3.0E-7;
48  }
49
50  static float tolerance()
51  {
52  return 1.0e-6;
53  }
54 };
55
56 // Default epsilon and tolerance for doubles
57 template <> struct RootTr<double>
58 {
59  enum
60  {
61  maxIter = 100
62  };
63
64  static double eps()
65  {
66  return 3.0E-15;
67  }
68
69  static double tolerance()
70  {
71  return 1.0e-12;
72  }
73 };
74
75 /*******************************************************************************
76  */
77 /// Brent's root solver
78 /**
79  * \tparam T Type for x and f(x) - must be floating point
80  * \tparam Func Function object describing function to solve where
81  * T f(x) = Func::operator()(const T& x)
82  *
83  * \param[out] numIter
84  * Number of iterations required for convergence to
85  * specified tolerance. If equal to MAXITER, the solution
86  * is not within specified tolerance.
87  * \param[in] f Instance of function to solve
88  * \param[in] aPt Lower bound
89  * \param[in] bPt Upper bound
90  * \param[in] tol Tolerance for solve - essentially the spread of the
91  * bracket. This can be specified in absolute terms, or,
92  * if given by an integer cast to T, the number of
93  * significant digits to solve for in x. The default
94  * is given by the RootTr class. Note that epsilon is
95  * also considered for specifying the spread of the
96  * bracket.
97  * \param[in] MAXITER
98  * Maximum iterations. Default (100). If reached,
99  * a message is written to cerr but the program otherwise
100  * completes
101  * \return x where f(x) = 0
102  *
103  * Example \verbatim
104  * #include <functional>
105  * #include "RootSolver.H"
106  * // Func is not allowed to be local until the C++0x standard
107  * struct Func : public std::unary_function<Real, Real>
108  * {
109  * Real operator()(const Real& a_x) const
110  * {
111  * return 5*std::pow(a_x, 5) - 3*std::pow(a_x, 3) + a_x;
112  * }
113  * };
114  * void foo()
115  * {
116  * int numIter;
117  * const Real xmin = -1.;
118  * const Real xmax = 1.;
119  * const Real x0 = RootSolver::Brent(numIter, Func(), xmin, xmax);
120  * if (numIter == RootTr<Real>::maxIter)
121  * {
122  * std::cout << "Uh oh\n";
123  * }
124  * }
125  * \endverbatim
126  *//*+*************************************************************************/
127
128 template <typename T, typename Func>
129 T Brent(int& numIter,
130  const Func& f,
131  T aPt,
132  T bPt,
133  T tol = RootTr<T>::tolerance(),
134  const unsigned MAXITER = RootTr<T>::maxIter)
135 {
136  CH_assert(tol >= 0.);
137  CH_assert(MAXITER > 0);
138  const T eps = RootTr<T>::eps();
139  const int prec = (int)tol;
140  if (((T)prec) == tol)
141  {
142  tol = std::pow(10., -std::abs(prec));
143  }
144  numIter = -1;
145
146  // Max allowed iterations and floating point precision
147  unsigned i;
148  T c, d, e;
149  T p, q, r, s;
150
151  T fa = f(aPt);
152  T fb = f(bPt);
153
154  // Init these to be safe
155  c = d = e = 0.0;
156
157  if (fb*fa > 0)
158  {
159  MayDay::Abort("RootSolver::Brent: Root must be bracketed");
160  }
161
162  T fc = fb;
163
164  for (i = 0; i < MAXITER; i++)
165  {
166  if (fb*fc > 0)
167  {
168  // Rename a, b, c and adjust bounding interval d
169  c = aPt;
170  fc = fa;
171  d = bPt - aPt;
172  e = d;
173  }
174
175  if (std::fabs(fc) < std::fabs(fb))
176  {
177  aPt = bPt;
178  bPt = c;
179  c = aPt;
180  fa = fb;
181  fb = fc;
182  fc = fa;
183  }
184
185  // Convergence check
186  const T tol1 = 2.0 * eps * std::fabs(bPt) + 0.5 * tol;
187  const T xm = 0.5 * (c - bPt);
188
189  if (std::fabs(xm) <= tol1 || fb == 0.0)
190  {
191  break;
192  }
193
194  if (std::fabs(e) >= tol1 && std::fabs(fa) > std::fabs(fb))
195  {
196  // Attempt inverse quadratic interpolation
197  s = fb / fa;
198  if (aPt == c)
199  {
200  p = 2.0 * xm * s;
201  q = 1.0 - s;
202  }
203  else
204  {
205  q = fa / fc;
206  r = fb / fc;
207  p = s * (2.0 * xm * q * (q-r) - (bPt-aPt) * (r-1.0));
208  q = (q-1.0) * (r-1.0) * (s-1.0);
209  }
210
211  // Check whether in bounds
212  if (p > 0) q = -q;
213
214  p = std::fabs(p);
215
216  if (2.0 * p < std::min(((float)3.0)*xm*q-std::fabs(tol1*q),
217  std::fabs(e*q)))
218  {
219  // Accept interpolation
220  e = d;
221  d = p / q;
222  }
223  else
224  {
225  // Interpolation failed, use bisection
226  d = xm;
227  e = d;
228  }
229  }
230  else
231  {
232  // Bounds decreasing too slowly, use bisection
233  d = xm;
234  e = d;
235  }
236
237  // Move last best guess to a
238  aPt = bPt;
239  fa = fb;
240
241  // Evaluate new trial root
242  if (std::fabs(d) > tol1)
243  {
244  bPt = bPt + d;
245  }
246  else
247  {
248  if (xm < 0) bPt = bPt - tol1;
249  else bPt = bPt + tol1;
250  }
251
252  fb = f(bPt);
253  }
254
255  if (i >= MAXITER)
256  {
257  std::cerr << "RootSolver::Brent: exceeded maximum iterations: "
258  << MAXITER << std::endl;
259  }
260
261  numIter = i;
262  return bPt;
263 }
264
265 } // End of namespace RootSolver
266
267 #include "BaseNamespaceFooter.H"
268
269 #endif
Definition: RootSolver.H:33
#define CH_assert(cond)
Definition: CHArray.H:37
IndexTM< T, N > min(const IndexTM< T, N > &a_p1, const IndexTM< T, N > &a_p2)
Definition: IndexTMI.H:396
T Brent(int &numIter, const Func &f, T aPt, T bPt, T tol=RootTr< T >::tolerance(), const unsigned MAXITER=RootTr< T >::maxIter)
Brent's root solver.
Definition: RootSolver.H:129
static double eps()
Definition: RootSolver.H:64
Definition: RootSolver.H:30
static double tolerance()
Definition: RootSolver.H:69
static float tolerance()
Definition: RootSolver.H:50
static float eps()
Definition: RootSolver.H:45
static void Abort(const char *const a_msg=m_nullString)
Print out message to cerr and exit via abort() (if serial) or MPI_Abort() (if parallel).