JoshJers' Ramblings

Infrequently-updated blog about software development, game development, and music

Emulating the FMAdd Instruction, Part 2: 64-bit Floats

(This post follows Part 1: 32-bit floats and will make very little sense without having read that one first. Honestly, it might make little sense having read that one first, I dunno!)

Last time we went over how to calculate the results of the FMAdd instruction (a fused-multiply-add calculated as if it had infinite internal precision) for 32-bit single-precision float values:

  • Calculate the double-precision product of a and b
  • Add this product to c to get a double-precision sum
  • Calculate the error of the sum
  • Use the error to odd-round the sum
  • Round the double-precision sum back down to single precision

This requires casting up to a 64-bit double-precision float to get extra bits of precision. But what if you can't do that? What if you're using doubles? You can't just (in most cases) cast up to a quad-precision float. So what do you do?

Double-Precision FMAdd

(Like the 32-bit version, this is based on Emulation of FMA and correctly-rounded sums: proved algorithms using rounding to odd by Sylvie Boldo and Guillaume Melquiond; there are additional details in that paper.)

To do this natively as doubles, we need to invent a new operation: MulWithError. This is the multiplication equivalent of the AddWithError function from the 32-bit solution:

(double prod, double err) MulWithError(double x, double y)
{
  double prod = x * y;
  double err = // ??? how do we do this
  return (prod, err);
}

We'll get to how to implement that in a moment, but first we'll walk through how to use that function to calculate a proper FMAdd.

We need to do the following:

  • Calculate the product of a and b and the error of that product
  • Calculate the sum of that product and c (giving us a * b + c) and the error of this sum
    • We're not using the error of the product ... yet
  • Add the two error terms (product error and sum error) together, rounding the result to odd
  • Add this summed error term to our actual result, which will round normally.

In code, that looks like this:

// Start with an "OddRoundedAdd" helper since we do
//  this operation frequently
double OddRoundedAdd(double x, double y)
{
  (double sum, double err) = AddWithError(x, y);
  return RoundToOdd(sum, err);
}

double FMAdd(double a, double b, double c)
{
  (double ab, double abErr) = MulWithError(a, b);
  (double abc, double abcErr) = AddWithError(ab, c);

  // Odd-round the sum of the two errors before 
  //  adding it in to the final result.
  double err = OddRoundedAdd(abErr, abcErr);
  return abc + err;
}

By keeping the error terms from both the product and the sum, we have kept all of the exact result. That is, we can assemble the mathematically-exact result given enough precision by doing abc + abErr + abcErr.

But we can't do infinite-precision addition of three values. However, we can odd-round an intermediate result, the same way we did with the single-precision case.

In this case, we know that abErr and abcErr both (necessarily) have much lower magnitudes than the final result, as each error value's highest bit is lower than the lowest bit of the mantissa of its respective operation. So, if we odd-round the sum of these two values, that sum effectively fulfills the condition of having more bits of precision than the final result. Thus, if we add the error terms together with odd rounding, the odd-rounded fake-sticky final digit will be taken into account by the actual sticky bit used when doing the final sum of the result and error terms.

I hope that makes sense?
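
If it helps, here's a minimal C++ sketch of the two helpers that carried over from Part 1, so the snippets in this post can be read on their own. Treat it as a sketch under assumptions rather than the exact Part 1 code: the SumAndError struct and HasOddMantissa helper are names made up for illustration, AddWithError is written as the branch-free two-sum, and the odd-rounding nudge uses std::nextafter. It also assumes finite inputs and a build without -ffast-math.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Hypothetical result type: sum + err equals the exact value of x + y.
struct SumAndError { double sum; double err; };

// Branch-free two-sum: recovers the rounding error of x + y without
//  needing to know which operand has the larger magnitude.
SumAndError AddWithError(double x, double y)
{
  double sum = x + y;
  double yPart = sum - x;       // the portion of y that made it into sum
  double xPart = sum - yPart;   // the portion of x that made it into sum
  double err = (x - xPart) + (y - yPart);
  return { sum, err };
}

// True if the lowest bit of v's stored mantissa is 1.
bool HasOddMantissa(double v)
{
  std::uint64_t bits;
  std::memcpy(&bits, &v, sizeof bits);
  return (bits & 1) != 0;
}

// If the sum was inexact and its mantissa is even, step one ulp toward
//  the exact value; that neighbor necessarily has an odd mantissa.
double RoundToOdd(double sum, double err)
{
  if (!std::isfinite(sum) || err == 0.0 || HasOddMantissa(sum))
    { return sum; }

  const double inf = std::numeric_limits<double>::infinity();
  return std::nextafter(sum, err > 0.0 ? inf : -inf);
}

double OddRoundedAdd(double x, double y)
{
  SumAndError r = AddWithError(x, y);
  return RoundToOdd(r.sum, r.err);
}
```

For example, 1.0 + 2^-60 rounds to 1.0 under normal rounding (even mantissa, positive error), so the odd-rounded sum is the next value up, 1.0 + 2^-52. An exact sum like 1.0 + 2.0 is returned unchanged.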

Breaking Down Multiplication With Error

(Like AddWithError, this is based on A Floating-Point Technique For Extending the Available Precision by T.J. Dekker)

So how do we calculate the error term of a 64-bit multiply? We can't use 128-bit values, but what we can do is break each 64-bit value up into two values, each with fewer bits of precision.

We'll break x and y (our two multiplicands) up into high and low values, where:

x = xh + xl;
y = yh + yl;

We do this by breaking the double's mantissa up:

  • xh contains the top 25 bits of x's mantissa (plus its implied 1, giving it 26 bits of precision).
  • xl contains the bottom 27 bits of x's mantissa. The highest-set '1' in this mantissa will become the implied 1 bit that's part of the standard floating-point format, so this value will have 27 bits of precision, max.
  • yh and yl are the same, but for y.

(double h, double l) Split(double v)
{
  // In C++, ZeroBottom27BitsOfMantissa can be implemented by masking off
  //  the bottom 27 bits after casting to a 64-bit int:
  // constexpr uint64_t Mask = ~uint64_t{0x07ff'ffff};
  // double h = std::bit_cast<double>(
  //              std::bit_cast<uint64_t>(v) & Mask);
  double h = ZeroBottom27BitsOfMantissa(v);

  // We can get the lower bits of the mantissa (correctly normalized and
  //  with correct signs) by subtracting the extracted upper bits from 
  //  the original value.
  double l = v - h;
  return (h, l);
}

What does this split give us? Well, we can now break the multiplication up into a sum of multiplies that each have enough bits of precision to be exactly representable (27BitValueA * 27BitValueB == 54BitValue, which fits perfectly in a double with the implied 1 bit), using our old friend from Algebra, FOIL:

    x * y 
= (xh + xl) * (yh + yl) 
= xh*yh + xh*yl + xl*yh + xl*yl;

We can't actually do those adds directly, but what we can do is similar to how we did AddWithError: use a sequence of precision-preserving operations to calculate the difference between that idealized result and our rounded product:

(double prod, double err) MulWithError(double x, double y)
{
  double prod = x * y;

  (xh, xl) = Split(x);
  (yh, yl) = Split(y);

  // Parentheses to demonstrate the precise order these 
  //  operations must occur in
  double err = (((xh*yh - prod) + xh*yl) + xl*yh) + xl*yl;
  return (prod, err);
}

It works like this:

  • Calculate the (rounded) product of x and y
  • Subtract that rounded product from the product of xh and yh
    • These should have roughly the same magnitude (and definitely the same sign) so this is a precision-preserving subtraction.
    • Since |xh * yh| <= rounded(|x * y|) (because xh and yh are truncated versions of x and y and thus have lower magnitudes) this is a smaller - larger operation and we'll get a result with a sign opposite that of the final product.
  • Keep adding in next-lower-magnitudes of values, which will continue to preserve precision
    • (because we have a value that is opposite-sign these are effectively subtractions, in the same way that a + -b is)
    • It's also worth noting here that xh*yl and xl*yh will have similar magnitudes, so the order that you add them in doesn't matter, as long as they're both after xh*yh and before xl*yl

Once you've done that, you have the computed product as well as the error term, and we can then follow our FMAdd algorithm above to calculate the FMAdd.
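
To make that concrete, here's roughly what Split and MulWithError could look like as actual C++. This is a sketch under a few assumptions: the HighLow and ProductAndError structs are names made up to stand in for the tuple returns, std::memcpy is used for the bit cast so it also compiles before C++20, and the build must not use -ffast-math or let the compiler fuse x * y into the subtraction that follows (for example, -ffp-contract=off on compilers that contract by default).

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical stand-ins for the tuple return values used in this post.
struct HighLow { double h; double l; };
struct ProductAndError { double prod; double err; };

HighLow Split(double v)
{
  // Clear the bottom 27 bits of the stored mantissa to get the high part.
  std::uint64_t bits;
  std::memcpy(&bits, &v, sizeof bits);
  bits &= ~std::uint64_t{0x07ffffff};
  double h;
  std::memcpy(&h, &bits, sizeof h);

  // The low part is whatever is left over (this subtraction is exact).
  return { h, v - h };
}

ProductAndError MulWithError(double x, double y)
{
  double prod = x * y;

  HighLow xs = Split(x);
  HighLow ys = Split(y);

  // Largest-magnitude terms first, exactly as parenthesized.
  double err =
    (((xs.h * ys.h - prod) + xs.h * ys.l) + xs.l * ys.h) + xs.l * ys.l;
  return { prod, err };
}
```

As a worked example, take x = y = 1 + 2^-30. The exact product is 1 + 2^-29 + 2^-60, which doesn't fit in a double, so prod rounds to 1 + 2^-29. Each value splits into h = 1 and l = 2^-30, and walking the sum gives -2^-29, then -2^-30, then 0, then 2^-60, which is exactly the bit that the rounded product dropped.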

So, that's it, we're done, right?

Edge Cases

Nope! Well, yes if you just wanted the gist, but now it's time to get into all those annoying implementation details that the papers this is based on completely glossed over. Here's where it gets ugly (unless you thought it was already ugly, in which case, sorry, it's about to get worse somehow).

In our single-precision case, we didn't have to worry about exponent overflow or underflow because we were using double-precision intermediates, which not only have additional mantissa range, but also additional exponent range.

It's possible that the product of a * b (an intermediate value in our calculation) goes out of range of what a double can represent, but that the addition of c might bring the final result back into range (which can happen when the sign of c is opposite the sign of a * b). This causes a different set of errors on either end:

  • If a*b is too large to be represented, it turns into infinity which means that adding c in will just leave it as infinity even though the final result should have been a representable value (albeit one with a very large magnitude)
  • If a*b is too small to be represented, it will go subnormal which means bits of the intermediate result will slide off the bottom of the mantissa and we lose bits of information, which can cause us to round incorrectly at our final result.

To solve this, we'll introduce a bias into the calculation, for when the value goes very small or very large:

double CalculateFMAddBias(double a, double b, double c)
{
  // Calculate what our final result would be if we just did it normally
  double testResult = Abs(a * b + c);

  if (testResult < Pow2(-500) && Max(Abs(a), Abs(b), Abs(c)) < Pow2(800))
  {
    // Our result is very small and our maximum value is not so large 
    //  that we'll blow up with a bias, so bias our values up to
    //  ensure we don't go subnormal in our intermediate result
    return Pow2(110);
  }
  else if (IsInfinite(testResult))
  {
    // We hit infinity, but that might be due to exponent overflow,
    //  so bias everything down (this may cause c to go subnormal, 
    //  but if that's the case then a*b on its own is infinity and
    //  so it won't affect the final result in any way)
    return Pow2(-55);
  }
  else
  {
    // No bias needed
    return 1.0;
  }
}

For any results that aren't extreme, the bias will remain 1.0 but, for values at the extremes, we'll scale our intermediates up or down (using powers of 2, which only affect the exponent and not the mantissa) into a range where they can't temporarily poke outside of what a double can represent. Also note that my powers of 2 are not perfectly chosen: I didn't bother trying to figure out the exact right biases/thresholds, so I just picked ones that I knew were good enough.
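
Here's a sketch of how the bias selection might look as compilable C++. Two things in it are assumptions rather than part of the pseudocode above: the product is stored in a volatile so that a compiler that fuses multiply-adds can't hide the overflow we're trying to detect, and the "maximum value" check is done on magnitudes (via std::fabs), which is how I read the intent of that test. Pow2 is implemented with std::ldexp.

```cpp
#include <algorithm>
#include <cmath>

// 2^exponent, which is exact for any exponent in a double's range.
double Pow2(int exponent)
{
  return std::ldexp(1.0, exponent);
}

double CalculateFMAddBias(double a, double b, double c)
{
  // Force the product to be rounded to a double before the add.
  volatile double product = a * b;
  double testResult = std::fabs(product + c);

  // Assumption: compare magnitudes, so large negative inputs count too.
  double maxMagnitude =
    std::max({ std::fabs(a), std::fabs(b), std::fabs(c) });

  if (testResult < Pow2(-500) && maxMagnitude < Pow2(800))
    { return Pow2(110); }  // very small result: scale up
  if (std::isinf(testResult))
    { return Pow2(-55); }  // overflowed: scale down
  return 1.0;              // no bias needed
}
```

Because the bias is always a power of 2, multiplying by it and dividing it back out is exact as long as nothing overflows or goes subnormal along the way; (0.1 * Pow2(110)) / Pow2(110) gives back exactly 0.1.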

So then we do our FMAdd calculation as before, but with the bias introduced (and then backed out at the end):

// Do our multiplication with the bias applied to 'a'
//  (the choice of applying it to 'a' vs 'b' is completely
//  arbitrary)
(double ab, double abErr) = MulWithError(a * bias, b);

// Then the sum with the bias applied to 'c'
(double abc, double abcErr) = AddWithError(ab, c * bias);

err = OddRoundedAdd(abErr, abcErr);

// Calculate our final result then un-bias the result.
return (abc + err) / bias;
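
To see the range problem in isolation, here's a stripped-down C++ sketch that leaves out the error terms entirely. That means it is not a correct FMAdd; it only shows what the bias buys us. NaiveMulAdd and BiasedMulAdd are hypothetical names for this illustration, and the volatile is there (as an assumption about the build) to keep the compiler from fusing the multiply and add.

```cpp
// Naive evaluation: the product is rounded (and can overflow to
//  infinity) before c ever gets added.
double NaiveMulAdd(double a, double b, double c)
{
  volatile double ab = a * b;
  return ab + c;
}

// Same calculation with a power-of-2 bias applied to 'a' and 'c',
//  then backed out of the result at the end.
double BiasedMulAdd(double a, double b, double c, double bias)
{
  volatile double ab = (a * bias) * b;
  double abc = ab + c * bias;
  return abc / bias;
}
```

With a as the largest finite double, b = 2 and c = -a, the exact answer is just a. NaiveMulAdd returns infinity because a * 2 overflows first, while BiasedMulAdd with a bias of 2^-55 keeps every intermediate in range and returns a.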

Alright, we've avoided both overflow and underflow and everything is great, right?

Two (Point Five) Last Annoying Implementation Details

Nope, sorry again! It turns out there are still two cases we need to deal with.

Case 1: Infinity or NaN even with the bias

If our result (without error applied) hits infinity even with the avoid-infinity bias, then we should just go ahead and return now to avoid Causing Problems Later (that is, turning what should be infinity into a NaN). And if it's already NaN we can just return now because it's going to be NaN forever.

Except, there's one additional necessary check here (for a case caught by dzaima over on Bluesky): in the event a and b are finite numbers but a * b blows up to infinity, and then c is the opposite infinity, the correct return value is whichever infinity (positive or negative) c is, so our early-out check has to catch that case as well:

  if (IsInfiniteOrNaN(abc))
  { 
    // If we got NaN (or Inf, which won't affect the output) and 
    //  a and b are both finite but c is infinite, return c (without
    //  this check, we will incorrectly return NaN instead of -Inf
    //  for FMAdd(1e200, 1e200, -Infinity))
    if (IsInfinite(c) && !IsInfiniteOrNaN(a) && !IsInfiniteOrNaN(b))
      { return c; }

    // Otherwise, return whichever Inf or NaN we got directly;
    return abc;
  }
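
Here's a small C++ sketch of that early-out that can be poked at directly. BiasedSum and ResolveNonFiniteResult are hypothetical names for this illustration: the first reproduces just the biased product-plus-c step (no error terms, with a volatile to stop the compiler fusing the operations), and the second is the check above pulled out into a function that assumes abc is already known to be Inf or NaN.

```cpp
#include <cmath>

// The biased a*b + c from the main algorithm, without the error terms.
double BiasedSum(double a, double b, double c, double bias)
{
  volatile double ab = (a * bias) * b;
  return ab + c * bias;
}

// Precondition: abc is already known to be Inf or NaN.
double ResolveNonFiniteResult(double a, double b, double c, double abc)
{
  // Finite a and b with an infinite c: the true product is finite,
  //  so the infinite c wins no matter what abc turned into.
  if (std::isinf(c) && std::isfinite(a) && std::isfinite(b))
    { return c; }

  // Otherwise, return whichever Inf or NaN we got directly.
  return abc;
}
```

For a = b = 1e200 and c = -Infinity, the product overflows to +Infinity even with the 2^-55 bias, so BiasedSum produces Inf + -Inf, which is NaN. ResolveNonFiniteResult turns that back into the correct -Infinity.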

Case 2: Subnormal Results

If our result is subnormal (after the bias is backed out), then it's going to lose bits of precision as it shifts down (because the exponent can't go any lower so instead the value itself shifts down the mantissa), which means whoops here's another rounding step, and the dreaded double-rounding returns.
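
A quick way to see that extra rounding step is to divide a biased value by the bias and check whether multiplying the bias back in recovers it. The sketch below does that in C++; UnbiasIsExact is a hypothetical helper for illustration, and it assumes the default floating-point environment (round-to-nearest-even, no flush-to-zero).

```cpp
#include <cmath>

const double AvoidSubnormalBias = std::ldexp(1.0, 110);

// Any biased value below this goes subnormal once the bias is removed.
const double SubnormThreshold =
  std::ldexp(1.0, -1022) * AvoidSubnormalBias;

// True when dividing out the bias loses no bits of the value.
bool UnbiasIsExact(double biased)
{
  double unbiased = biased / AvoidSubnormalBias;
  return unbiased * AvoidSubnormalBias == biased;
}
```

A biased value of 2^-900 with its lowest mantissa bit set un-biases to a normal number and survives the round trip. Move that same mantissa down to 2^-913 (just under SubnormThreshold) and the un-biased value lands in subnormal range, where the lowest bit no longer fits: the division has to round it away, so the round trip fails. That rounding is the one we now have to account for.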

In this case we need to actually odd-round the addition of the error term as well, so that when the bias is backed out and it rounds, it does the correct thing:

// Multiply the smallest-representable normalized value by our avoid-
//  subnormal bias. Any (biased) value below this will go subnormal.
//  (In production code it'd be nicer to use something like
//  std::numeric_limits instead of hard-coding -1022)
const double SubnormThreshold = Pow2(-1022) * AvoidSubnormalBias;

if (bias == AvoidSubnormalBias && Abs(abc) < SubnormThreshold)
{
  // Odd-round the addition of the error so that the rounding that 
  //  happens on the divide by the bias is correct.
  (double finalSum, double finalSumErr) = AddWithError(abc, err);
  finalSum = RoundToOdd(finalSum, finalSumErr);
  return finalSum / bias;
}

And this almost works, except there's one more annoying case, and that's where our result is going subnormal, but only by exactly one bit. Remember that the odd-rounding trick only works if we have two or more bits so that the final rounding works properly, but in this case we're truncating the mantissa by exactly one bit, so we have to do even more work:

  • Split the value that will be shifting down into a high and low part (same as we did for the multiply)
  • Add our error term to the low part of it
    • This preserves additional bits of the error term since we gave ourselves more headroom by removing the upper half of its mantissa
  • Remove the bias from both the high and low parts separately
    • Removing the bias from the high part doesn't round since we know the lowest bit is 0
    • Removing the bias from the low part applies the actual final rounding (correctly) since we gave ourselves more bits to work with
  • Sum the halves back together and return that as our final result
    • This sum is (thankfully) perfectly representable by the final precision and doesn't introduce any additional error.

const double OneBitSubnormalThreshold =
  SubnormThreshold * 0.5;
if (Abs(finalSum) >= OneBitSubnormalThreshold)
{
  // Split into halves
  (rh, rl) = Split(finalSum);

  // Add the error term into the low part of the split
  rl = OddRoundedAdd(rl, finalSumErr);

  // Scale them both down by the bias. Note that 
  //  the rh division cannot round since the lowest bit
  //  is 0
  rh /= bias;

  // This division is what actually introduces the final
  //  rounding (correctly, since we gave ourselves more
  //  bits to work with)
  rl /= bias;

  // This sum is perfectly representable by the final
  //  precision and will not introduce additional error.
  return rh + rl;
}

OMG Are We Done Now?

As far as I'm aware, those are all the implementation details for a 64-bit double-precision FMAdd. It's conceptually not that much more complicated than the 32-bit one, but mechanically it's worse, plus there are those fun extra edge cases to consider.

Here's the final code:

(double h, double l) Split(double v)
{
  double h = ZeroBottom27BitsOfMantissa(v);
  double l = v - h;
  return (h, l);
}

(double prod, double err) MulWithError(
  double x, 
  double y)
{
  double prod = x * y;

  (xh, xl) = Split(x);
  (yh, yl) = Split(y);
  double err = 
    (((xh*yh - prod) + xh*yl) + xl*yh) + xl*yl;
  return (prod, err);
}

double OddRoundedAdd(double x, double y)
{
  (double sum, double err) = AddWithError(x, y);
  return RoundToOdd(sum, err);
}

double FMAdd(double a, double b, double c)
{
  const double AvoidSubnormalBias = Pow2(110);
  double bias = 1.0;
  {
    // Calculate our final result as if done normally
    double testResult = Abs(a * b + c);

    // Bias if the result goes too low or too high
    if (testResult < Pow2(-500) && Max(Abs(a), Abs(b), Abs(c)) < Pow2(800))
      { bias = AvoidSubnormalBias; } // too low
    else if (IsInfinite(testResult))
      { bias = Pow2(-55); } // too high
  }

  // Calculate using our bias
  (double ab, double abErr) = MulWithError(a * bias, b);
  (double abc, double abcErr) = AddWithError(ab, c * bias);

  // Check for infinity or NaN and return early
  if (IsInfiniteOrNaN(abc))
  { 
    // Handle the "a multiply of two finite values hit infinity even
    //  *with* the bias, but c is the opposite infinity" case and
    //  return the correct result of "c"
    if (IsInfinite(c) && !IsInfiniteOrNaN(a) && !IsInfiniteOrNaN(b))
      { return c; }

    // Otherwise just return the inf or nan directly
    return abc; 
  }

  // Odd-round the intermediate error result
  double err = OddRoundedAdd(abErr, abcErr);

  // Multiply the smallest-representable normalized value by our avoid-
  //  subnormal bias. Any (biased) value below this will go subnormal
  const double SubnormThreshold = Pow2(-1022) * AvoidSubnormalBias;

  if (bias == AvoidSubnormalBias && Abs(abc) < SubnormThreshold)
  {
    (double finalSum, double finalSumErr) = AddWithError(abc, err);

    // This is half of SubnormThreshold. Any value between SubnormThreshold
    //  and this value will only lose a single bit of precision when
    //  the bias is removed, which requires some extra care
    const double OneBitSubnormalThreshold =
      SubnormThreshold * 0.5;
    if (Abs(finalSum) >= OneBitSubnormalThreshold)
    {
      // Split into halves
      (rh, rl) = Split(finalSum);

      // Add the error term into the LOW part of our split value
      rl = OddRoundedAdd(rl, finalSumErr);

      // Divide out the bias from both halves (which will cause rl to
      //  round to its final, correctly-rounded value) then sum them 
      //  together (which is perfectly representable).
      rh /= bias;
      rl /= bias;
      return rh + rl;
    }
    else
    {
      // For more-than-one-bit subnormals, we do an odd-rounded addition of
      //  the error term and then divide out the bias, doing full rounding
      //  just once.
      finalSum = RoundToOdd(finalSum, finalSumErr);
      return finalSum / bias;
    }
  }
  else
  {
    // Not subnormal, so we can calculate our final result normally and un-
    //  bias the result.
    return (abc + err) / bias;
  }
}

Compare that to the 32-bit version and you can see why this one got its own post:

float FMAdd(float a, float b, float c)
{
  double product = double(a) * double(b);
  (double sum, double err) = AddWithError(product, c);
  sum = RoundToOdd(sum, err);
  return float(sum);
}

Hopefully you never have to actually implement this yourself, but if you do? I hope this helps.