//! @file rts-wapr.c
//! @author J. Marcel van der Veer
//
//! @section Copyright
//
// This file is part of VIF - vintage FORTRAN compiler.
// Copyright 2020-2025 J. Marcel van der Veer <algol68g@xs4all.nl>.
//
//! @section License
//
// This program is free software; you can redistribute it and/or modify it 
// under the terms of the GNU General Public License as published by the 
// Free Software Foundation; either version 3 of the License, or 
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful, but 
// WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for 
// more details. You should have received a copy of the GNU General Public 
// License along with this program. If not, see <http://www.gnu.org/licenses/>.

//! @section Synopsis
//!
//! Runtime support implementing the Lambert W-function.

#include <vif.h>

// The Lambert W function y = W(x) is the solution to the equation y * exp(y) = x. 
//
// Original FORTRAN77 version by Andrew Barry, S. J. Barry, Patricia Culligan-Hensley.
// Original C version by John Burkardt, distributed under the MIT license [2014].
// Adapted for VIF by J.M. van der Veer [2024].
//
// Reference:
//   Andrew Barry, S. J. Barry, Patricia Culligan-Hensley,
//   Algorithm 743: WAPR - A Fortran routine for calculating real values of the W-function,
//   ACM Transactions on Mathematical Software,
//   Volume 21, Number 2, June 1995, pages 172-181.

// Forward declarations of the routines defined below.

// Bisection evaluation of W; *ner is set to 1 if the iteration limit is hit.
real_8 bisect (real_8 xx, int_4 nb, int_4 * ner, int_4 l);
// Closed-form first estimate of W on the branch selected by nb.
real_8 crude (real_8 xx, int_4 nb);
// Mantissa length in bits, minus one, as reported through _i1mach.
int_4 nbits_compute ();
// Main evaluator of W; *nerror is set to 1 if the argument is out of range.
real_8 wapr (real_8 x, int_4 nb, int_4 * nerror, int_4 l);
// By-reference entry point that forwards to wapr.
real_8 _wapr (real_8 * x, int_4 * nb, int_4 * nerror, int_4 * l);

real_8 bisect (real_8 xx, int_4 nb, int_4 *ner, int_4 l)
{
// BISECT evaluates the W function by interval halving.
// After TOMS algorithm 743.
//
// Discussion:
//
//   The bracket is seeded from CRUDE, widened by 1.0e-3 on either side
//   and clipped at -1. Halving stops once the half-width drops to
//   |upper end| * 2^(-NBITS), where NBITS comes from nbits_compute ()
//   and is cached after the first call, or once the residual is
//   exactly zero.
//
//   At most 500 halving steps are taken.
//
//   On the upper branch, arguments with |X| < 2^(-NBITS/7) bypass the
//   bisection and use a six-level nested exponential, which is exact
//   to O(X^8).
//
// Parameters:
//
//   Input, real_8 XX, the argument.
//
//   Input, int_4 NB, selects the branch of the W function.
//   0, the upper branch;
//   nonzero, the lower branch.
//
//   Output, int_4 *NER, the error flag.
//   0, success;
//   1, the step limit was reached without meeting the tolerance.
//
//   Input, int_4 L, the offset indicator.
//   1, XX is the offset of the argument from -exp(-1);
//   not 1, XX is the actual argument.
//
//   Output, real_8 BISECT, the value of W(X) as determined by bisection
//   (the last midpoint when *NER is 1).

  const int_4 max_steps = 500;
  static int_4 mant_bits = 0;
  int_4 step, below_e;
  real_8 lo, hi, half, eps, x, f_mid, f_lo, nest, mid = 0.0;
  *ner = 0;
  if (mant_bits == 0) {
    mant_bits = nbits_compute ();
  }
  x = (l == 1) ? xx - exp (-1.0) : xx;
  if (nb == 0) {
// Upper branch: tiny arguments use the nested exponential directly.
    if (fabs (x) < 1.0 / pow (pow (2.0, mant_bits), (1.0 / 7.0))) {
      nest = exp (-x);
      for (step = 0; step < 5; step++) {
        nest = exp (-x * nest);
      }
      return x * nest;
    }
    hi = crude (x, nb) + 1.0e-3;
    eps = fabs (hi) / pow (2.0, mant_bits);
    lo = fmax (hi - 2.0e-3, -1.0);
// The residual form depends only on x, so decide it once.
    below_e = (x < exp (1.0));
    for (step = 0; step < max_steps; step++) {
      half = 0.5 * (hi - lo);
      mid = lo + half;
      if (below_e) {
// w*exp(w)-x avoids ln(0).
        f_mid = mid * exp (mid) - x;
        f_lo = lo * exp (lo) - x;
      } else {
// ln(w/x)+w avoids overflow.
        f_mid = log (mid / x) + mid;
        f_lo = log (lo / x) + lo;
      }
      if (f_mid == 0.0 || fabs (half) <= eps) {
        return mid;
      }
// Keep the half of the bracket across which the residual changes sign.
      if (0.0 < f_lo * f_mid) {
        lo = mid;
      } else {
        hi = mid;
      }
    }
  } else {
// Lower branch: the bracket lies at or below -1.
    lo = crude (x, nb) - 1.0e-3;
    hi = fmin (lo + 2.0e-3, -1.0);
    eps = fabs (hi) / pow (2.0, mant_bits);
    for (step = 0; step < max_steps; step++) {
      half = 0.5 * (hi - lo);
      mid = lo + half;
      f_mid = mid * exp (mid) - x;
      if (f_mid == 0.0 || fabs (half) <= eps) {
        return mid;
      }
      f_lo = lo * exp (lo) - x;
      if (0.0 < f_lo * f_mid) {
        lo = mid;
      } else {
        hi = mid;
      }
    }
  }
// Step limit exhausted without convergence.
  *ner = 1;
  return mid;
}

real_8 crude (real_8 xx, int_4 nb)
{
// CRUDE returns a crude approximation for the W function.
//
// Discussion:
//
//   The result is a closed-form estimate chosen by branch and by the
//   size of XX; no iteration is performed. The constants that do not
//   depend on XX are computed on the first call and cached in statics.
//   No range check is made on XX.
//
// Parameters:
//
//   Input, real_8 XX, the argument.
//
//   Input, int_4 NB, indicates the desired branch.
//   * 0, the upper branch;
//   * nonzero, the lower branch.
//
//   Output, real_8 CRUDE, the crude approximation to W at XX.

// Per-call scratch values; eta is recomputed on every use, so it needs
// no static storage.
  real_8 an2, eta, reta, t, ts, zl;
  static int_4 init = 0;
  static real_8 c13, em, em2, em9, s2, s21, s22, s23;
// Various mathematical constants.
  if (init == 0) {
    init = 1;
    em = -exp (-1.0);
    em9 = -exp (-9.0);
    c13 = 1.0 / 3.0;
    em2 = 2.0 / em;
    s2 = sqrt (2.0);
    s21 = 2.0 * s2 - 3.0;
    s22 = 4.0 - 3.0 * s2;
    s23 = s2 - 2.0;
  }
// Crude Wp.
  if (nb == 0) {
    if (xx <= 20.0) {
// Estimate built on reta = sqrt (2 * (1 - xx / em)).
      reta = s2 * sqrt (1.0 - xx / em);
      an2 = 4.612634277343749 * sqrt (sqrt (reta + 1.09556884765625));
      return reta / (1.0 + reta / (3.0 + (s21 * an2 + s22) * reta / (s23 * (an2 + reta)))) - 1.0;
    } else {
// Logarithmic estimate for arguments above 20.
      zl = log (xx);
      return log (xx / log (xx / pow (zl, exp (-1.124491989777808 / (0.4225028202459761 + zl)))));
    }
  } else {
// Crude Wm.
    if (xx <= em9) {
// Arguments at or below -exp(-9).
      zl = log (-xx);
      t = -1.0 - zl;
      ts = sqrt (t);
      return zl - (2.0 * ts) / (s2 + (c13 - t / (270.0 + ts * 127.0471381349219)) * ts);
    } else {
// Arguments above -exp(-9).
      zl = log (-xx);
      eta = 2.0 - em2 * xx;
      return log (xx / log (-xx / ((1.0 - 0.5043921323068457 * (zl + 1.0)) * (sqrt (eta) + eta / 3.0) + 1.0)));
    }
  }
}

int_4 nbits_compute (void)
{
// NBITS_COMPUTE computes the mantissa length minus one.
//
// Discussion:
//
//   NBITS is the number of bits (less 1) in the mantissa of the
//   floating point number representation of your machine.
//   It is used to determine the level of accuracy to which the W
//   function should be calculated.
//
//   The value is taken from _i1mach with selector 14, which is passed
//   by reference.
//   NOTE(review): presumably selector 14 follows the classic I1MACH
//   numbering (digits in the double precision mantissa) - confirm
//   against the _i1mach implementation.
//
// Parameters:
//
//   Output, int_4 NBITS_COMPUTE, the mantissa length, in bits, minus one.
//
  int m = 14;
  return _i1mach (&m) - 1;
}

real_8 wapr (real_8 x, int_4 nb, int_4 *nerror, int_4 l)
{
// WAPR approximates the W function.
//
// Discussion:
//
//   The call will fail if the input value X is out of range.
//   The range requirement for the upper branch is:
//     -exp(-1) <= X.
//   The range requirement for the lower branch is:
//     -exp(-1) < X < 0.
//
//   Close to zero (upper branch only) and close to -exp(-1) (both
//   branches) a closed-form approximation is returned directly.
//   Elsewhere a first estimate is improved by NITER refinement steps,
//   where NITER is 1, or 2 when the mantissa has 56 or more bits.
//
//   Machine-dependent thresholds and constants are computed on the
//   first call and cached in statics.
//
// Parameters:
//
//   Input, real_8 X, the argument.
//
//   Input, int_4 NB, indicates the desired branch.
//   * 0, the upper branch;
//   * nonzero, the lower branch.
//
//   Output, int_4 *NERROR, the error flag.
//   * 0, successful call.
//   * 1, failure, the input X is out of range.
//
//   Input, int_4 L, indicates the interpretation of X.
//   * 1, X is actually the offset from -exp(-1), so compute W(X-exp(-1)).
//   * not 1, X is the argument; compute W(X);
//
//   Output, real_8 WAPR, the approximate value of W(X);
//   0.0 when *NERROR is set to 1.

  int_4 i;
  real_8 an2, delx, eta, reta, t, temp, temp2, ts, xx, zl, zn, value = 0.0;
  static int_4 init = 0, nbits, niter = 1;
  static real_8 an3, an4, an5, an6, c13, c23, d12, em, em2, em9;
  static real_8 s2, s21, s22, s23, tb, x0, x1;
  *nerror = 0;
  if (init == 0) {
    init = 1;
    nbits = nbits_compute ();
    if (56 <= nbits) {
      niter = 2;
    }
// Various mathematical constants.
    em = -exp (-1.0);
    em9 = -exp (-9.0);
    c13 = 1.0 / 3.0;
    c23 = 2.0 * c13;
    em2 = 2.0 / em;
    d12 = -em2;
// tb = 2^(-nbits); x0 and x1 bound the regions handled without refinement.
    tb = pow (0.5, nbits);
    x0 = pow (tb, 1.0 / 6.0) * 0.5;
    x1 = (1.0 - 17.0 * pow (tb, 2.0 / 7.0)) * em;
    an3 = 8.0 / 3.0;
    an4 = 135.0 / 83.0;
    an5 = 166.0 / 39.0;
    an6 = 3167.0 / 3549.0;
    s2 = sqrt (2.0);
    s21 = 2.0 * s2 - 3.0;
    s22 = 4.0 - 3.0 * s2;
    s23 = s2 - 2.0;
  }
// Establish both the argument xx and its offset delx from -exp(-1).
  if (l == 1) {
    delx = x;
    if (delx < 0.0) {
      *nerror = 1;
      RTE ("wapr", "offset X must be non-negative");
// Whether RTE returns is decided outside this file. Return explicitly,
// like every other error path here, so that a negative offset can never
// reach the sqrt () calls below.
      return value;
    }
    xx = x + em;
  } else {
    if (x < em) {
      *nerror = 1;
      return value;
    } else if (x == em) {
// Branch point: W(-exp(-1)) = -1 on either branch.
      value = -1.0;
      return value;
    }
    xx = x;
    delx = xx - em;
  }
// Calculations for Wp.
  if (nb == 0) {
    if (fabs (xx) <= x0) {
// Near zero: rational approximation, returned without refinement.
      value = xx / (1.0 + xx / (1.0 + xx / (2.0 + xx / (0.6 + 0.34 * xx))));
      return value;
    } else if (xx <= x1) {
// Near -exp(-1): approximation in sqrt (d12 * delx), returned without refinement.
      reta = sqrt (d12 * delx);
      value = reta / (1.0 + reta / (3.0 + reta / (reta / (an4 + reta / (reta * an6 + an5)) + an3))) - 1.0;
      return value;
    } else if (xx <= 20.0) {
      reta = s2 * sqrt (1.0 - xx / em);
      an2 = 4.612634277343749 * sqrt (sqrt (reta + 1.09556884765625));
      value = reta / (1.0 + reta / (3.0 + (s21 * an2 + s22) * reta / (s23 * (an2 + reta)))) - 1.0;
    } else {
      zl = log (xx);
      value = log (xx / log (xx / pow (zl, exp (-1.124491989777808 / (0.4225028202459761 + zl)))));
    }
  }
// Calculations for Wm.
  else {
    if (0.0 <= xx) {
// The lower branch is only evaluated for negative arguments.
      *nerror = 1;
      return value;
    } else if (xx <= x1) {
// Near -exp(-1): approximation in sqrt (d12 * delx), returned without refinement.
      reta = sqrt (d12 * delx);
      value = reta / (reta / (3.0 + reta / (reta / (an4 + reta / (reta * an6 - an5)) - an3)) - 1.0) - 1.0;
      return value;
    } else if (xx <= em9) {
      zl = log (-xx);
      t = -1.0 - zl;
      ts = sqrt (t);
      value = zl - (2.0 * ts) / (s2 + (c13 - t / (270.0 + ts * 127.0471381349219)) * ts);
    } else {
      zl = log (-xx);
      eta = 2.0 - em2 * xx;
      value = log (xx / log (-xx / ((1.0 - 0.5043921323068457 * (zl + 1.0)) * (sqrt (eta) + eta / 3.0) + 1.0)));
    }
  }
// Refine the first estimate; zn is the residual of ln (xx / w) = w.
  for (i = 1; i <= niter; i++) {
    zn = log (xx / value) - value;
    temp = 1.0 + value;
    temp2 = temp + c23 * zn;
    temp2 = 2.0 * temp * temp2;
    value = value * (1.0 + (zn / temp) * (temp2 - zn) / (temp2 - 2.0 * zn));
  }
  return value;
}

real_8 _wapr (real_8 *x, int_4 *nb, int_4 *nerror, int_4 *l)
{
// F77 API: every argument arrives by reference. The inputs are
// dereferenced here; the error flag pointer is handed to wapr unchanged.
  const real_8 arg = *x;
  const int_4 branch = *nb;
  const int_4 offset_mode = *l;
  return wapr (arg, branch, nerror, offset_mode);
}
