Chombo + EB  3.2
RootSolver.H
Go to the documentation of this file.
1 #ifdef CH_LANG_CC
2 /*
3  * _______ __
4  * / ___/ / ___ __ _ / / ___
5  * / /__/ _ \/ _ \/ V \/ _ \/ _ \
6  * \___/_//_/\___/_/_/_/_.__/\___/
8  */
9 #endif
10
11 #ifndef _ROOTSOLVER_H_
12 #define _ROOTSOLVER_H_
13
14 /******************************************************************************/
15 /**
16  * \file
17  *
18  * \brief Root solvers
19  *
20  *//*+*************************************************************************/
21
22 #include <cmath>
23 #include <algorithm>
24
25 #include "CH_assert.H"
26
28
29 namespace RootSolver
30 {
31
/// Numerical defaults for the root solvers, selected by value type.
/**
 * Only float and double are supported: the primary template deliberately
 * declares no members, so any use of RootTr<T>::maxIter, eps() or
 * tolerance() with another T fails to compile.
 */
template <class T>
struct RootTr
{
  // Intentionally empty -- see the float and double specializations below.
};

/// Single-precision defaults
template <>
struct RootTr<float>
{
  /// Default cap on the number of solver iterations
  enum { maxIter = 100 };

  /// Epsilon scaling the |x|-relative part of the convergence test
  static float eps()
  {
    return static_cast<float>(3.0E-7);
  }

  /// Default absolute tolerance for a solve
  static float tolerance()
  {
    return static_cast<float>(1.0e-6);
  }
};

/// Double-precision defaults
template <>
struct RootTr<double>
{
  /// Default cap on the number of solver iterations
  enum { maxIter = 100 };

  /// Epsilon scaling the |x|-relative part of the convergence test
  static double eps()
  {
    const double epsilon = 3.0E-15;
    return epsilon;
  }

  /// Default absolute tolerance for a solve
  static double tolerance()
  {
    const double defaultTol = 1.0e-12;
    return defaultTol;
  }
};
73
74 /*******************************************************************************
75  */
76 /// Brent's root solver
77 /**
78  * \tparam T Type for x and f(x) - must be floating point
79  * \tparam Func Function object describing function to solve where
80  * T f(x) = Func::operator()(const T& x)
81  *
82  * \param[out] numIter
83  * Number of iterations required for convergence to
84  * specified tolerance. If equal to MAXITER, the solution
85  * is not within specified tolerance.
86  * \param[in] f Instance of function to solve
87  * \param[in] aPt Lower bound
88  * \param[in] bPt Upper bound
89  * \param[in] tol Tolerance for solve - essentially the spread of the
90  * bracket. This can be specified in absolute terms, or,
91  * if given by an integer cast to T, the number of
92  * significant digits to solve for in x. The default
93  * is given by the RootTr class. Note that epsilon is
94  * also considered for specifying the spread of the
95  * bracket.
96  * \param[in] MAXITER
97  * Maximum iterations. Default (100). If reached,
98  * a message is written to cerr but the program otherwise
99  * completes
100  * \return x where f(x) = 0
101  *
102  * Example \verbatim
103  * #include <functional>
104  * #include "RootSolver.H"
105  * // Func is not allowed to be local until the C++0x standard
106  * struct Func : public std::unary_function<Real, Real>
107  * {
108  * Real operator()(const Real& a_x) const
109  * {
110  * return 5*std::pow(a_x, 5) - 3*std::pow(a_x, 3) + a_x;
111  * }
112  * };
113  * void foo()
114  * {
115  * int numIter;
116  * const Real xmin = -1.;
117  * const Real xmax = 1.;
118  * const Real x0 = RootSolver::Brent(numIter, Func(), xmin, xmax);
119  * if (numIter == RootTr<Real>::maxIter)
120  * {
121  * std::pout() << "Uh oh\n";
122  * }
123  * }
124  * \endverbatim
125  *//*+*************************************************************************/
126
127 template <typename T, typename Func>
128 T Brent(int& numIter,
129  const Func& f,
130  T aPt,
131  T bPt,
132  T tol = RootTr<T>::tolerance(),
133  const unsigned MAXITER = RootTr<T>::maxIter)
134 {
135  CH_assert(tol >= 0.);
136  CH_assert(MAXITER > 0);
137  const T eps = RootTr<T>::eps();
138  const int prec = (int)tol;
139  if (((T)prec) == tol)
140  {
141  tol = std::pow(10., -std::abs(prec));
142  }
143  numIter = -1;
144
145  // Max allowed iterations and floating point precision
146  unsigned i;
147  T c, d, e;
148  T p, q, r, s;
149
150  T fa = f(aPt);
151  T fb = f(bPt);
152
153  // Init these to be safe
154  c = d = e = 0.0;
155
156  if (fb*fa > 0)
157  {
158  MayDay::Abort("RootSolver::Brent: Root must be bracketed");
159  }
160
161  T fc = fb;
162
163  for (i = 0; i < MAXITER; i++)
164  {
165  if (fb*fc > 0)
166  {
167  // Rename a, b, c and adjust bounding interval d
168  c = aPt;
169  fc = fa;
170  d = bPt - aPt;
171  e = d;
172  }
173
174  if (Abs(fc) < Abs(fb))
175  {
176  aPt = bPt;
177  bPt = c;
178  c = aPt;
179  fa = fb;
180  fb = fc;
181  fc = fa;
182  }
183
184  // Convergence check
185  const T tol1 = 2.0 * eps * Abs(bPt) + 0.5 * tol;
186  const T xm = 0.5 * (c - bPt);
187
188  if (Abs(xm) <= tol1 || fb == 0.0)
189  {
190  break;
191  }
192
193  if (Abs(e) >= tol1 && Abs(fa) > Abs(fb))
194  {
195  // Attempt inverse quadratic interpolation
196  s = fb / fa;
197  if (aPt == c)
198  {
199  p = 2.0 * xm * s;
200  q = 1.0 - s;
201  }
202  else
203  {
204  q = fa / fc;
205  r = fb / fc;
206  p = s * (2.0 * xm * q * (q-r) - (bPt-aPt) * (r-1.0));
207  q = (q-1.0) * (r-1.0) * (s-1.0);
208  }
209
210  // Check whether in bounds
211  if (p > 0) q = -q;
212
213  p = Abs(p);
214
215  if (2.0 * p < std::min(((float)3.0)*xm*q-Abs(tol1*q),
216  Abs(e*q)))
217  {
218  // Accept interpolation
219  e = d;
220  d = p / q;
221  }
222  else
223  {
224  // Interpolation failed, use bisection
225  d = xm;
226  e = d;
227  }
228  }
229  else
230  {
231  // Bounds decreasing too slowly, use bisection
232  d = xm;
233  e = d;
234  }
235
236  // Move last best guess to a
237  aPt = bPt;
238  fa = fb;
239
240  // Evaluate new trial root
241  if (Abs(d) > tol1)
242  {
243  bPt = bPt + d;
244  }
245  else
246  {
247  if (xm < 0) bPt = bPt - tol1;
248  else bPt = bPt + tol1;
249  }
250
251  fb = f(bPt);
252  }
253
254  if (i >= MAXITER)
255  {
256  pout() << "RootSolver::Brent: exceeded maximum iterations: "
257  << MAXITER << std::endl;
258  }
259
260  numIter = i;
261  return bPt;
262 }
263
264 } // End of namespace RootSolver
265
266 #include "BaseNamespaceFooter.H"
267
268 #endif
std::ostream & pout()
Use this in place of std::cout for program output.
Definition: RootSolver.H:32
#define CH_assert(cond)
Definition: CHArray.H:37
IndexTM< T, N > min(const IndexTM< T, N > &a_p1, const IndexTM< T, N > &a_p2)
Definition: IndexTMI.H:394
T Brent(int &numIter, const Func &f, T aPt, T bPt, T tol=RootTr< T >::tolerance(), const unsigned MAXITER=RootTr< T >::maxIter)
Brent's root solver.
Definition: RootSolver.H:128
static double eps()
Definition: RootSolver.H:63
T Abs(const T &a_a)
Definition: Misc.H:53
static double tolerance()
Definition: RootSolver.H:68
static float tolerance()
Definition: RootSolver.H:49
static float eps()
Definition: RootSolver.H:44
static void Abort(const char *const a_msg=m_nullString)
Print out message to cerr and exit via abort() (if serial) or MPI_Abort() (if parallel).