Emulating the FMAdd Instruction, Part 1: 32-bit Floats
One thing I had to do at work was write an emulation of the FMAdd (fused multiply-add) instruction for hardware that doesn't natively support it (specifically, I was writing a SIMD implementation, but the idea is the same), so I thought I'd share a bit about how FMAdd works, since I've already been posting about how float rounding works.
So, screw it, here we go with another unnecessarily technical, mathy post!
What is the FMAdd Instruction?
A fused multiply-add does a multiply and an add as a single operation, and it gives you the result as if it were computed with infinite precision and then rounded only once, at the very end. That is, FMAdd computes (a * b) + c without introducing any intermediate floating-point error:
float FMAdd(float a, float b, float c)
{
    // ??? Somehow do this with no intermediate rounding
    return (a * b) + c;
}
Computing it normally (using the code above) for some values will get you double rounding (explained in a moment) which means you might be an extra bit off (or, more formally, one ULP) from where your actual result should be. An extra bit doesn't sound like a lot, but it can add up over many operations.
Fused multiply-add avoids this extra rounding, making it more accurate than a multiply followed by a separate add, which is great! (It can also be faster if it's supported by hardware but, as you'll see, computing it without a dedicated instruction on the CPU is actually surprisingly spendy, especially once you get into doing it for 64-bit floats, but sometimes you need precision instead of performance).
Double Rounding
Double rounding happens when the intermediate value rounds down (or up), then the final result also rounds in the same direction - but because of the first rounding, actually overshoots the correctly-rounded final value by a bit.
Here's an example using two successive sums of some 4-bit float values. We'll do the following sum (in top-down order):
1.000 * 2^4
+ 1.001 * 2^0
+ 1.100 * 2^0
The first sum, done with "infinite" internal precision, looks like this:
1.000 * 2^4
+ 1.001 * 2^0
-----------
1.000 0000 * 2^4
+ 0.000 1001 * 2^4
----------------
1.000 1001 * 2^4
If we were to then use that result directly (with no intermediate rounding) and do the second sum, only rounding the final result:
1.000 1001 * 2^4
+ 0.000 1100 * 2^4
----------
= 1.001 0101 * 2^4
-> 1.001 * 2^4 // Rounds down
The final result rounds (to nearest) to 1.001 * 2^4.
However, if we were to round that intermediate value to 4 bits first, we'd get this:
1.000 1001 * 2^4
-> 1.001 * 2^4 // Rounded up
+ 0.000 1100 * 2^4
----------
1.001 1100 * 2^4
-> 1.010 * 2^4 // Up again
In this one, we end up with 1.010 * 2^4 instead of 1.001 * 2^4 because of the intermediate rounding, which pushed us past the correctly-rounded final result.
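The 4-bit example above can be replayed with plain integers. This is a minimal sketch under a simplifying assumption: every value is stored as an integer count of 2^-7 units at exponent 4 (so 1.000 1001 * 2^4 is 0b10001001), and the sums stay in [1.0, 2.0) * 2^4 so no renormalization is needed. RoundTo4Bits, RoundOnce, and RoundTwice are hypothetical names for illustration.

```cpp
#include <cstdint>

// Round an 8-bit significand (1.xxx xxxx, in units of 2^-7) to 4 significant
// bits using round to nearest, ties to even.
// Assumes 128 <= v < 256 and that rounding never carries up to 2.0.
std::uint32_t RoundTo4Bits(std::uint32_t v)
{
    std::uint32_t kept = v >> 4;   // 1.xxx: the 4 bits we keep
    std::uint32_t rest = v & 0xF;  // the 4 bits being rounded away
    if (rest > 0x8 || (rest == 0x8 && (kept & 1u) != 0))
        kept += 1;                 // above halfway, or a tie with an odd kept value
    return kept << 4;              // back to units of 2^-7
}

const std::uint32_t kA = 0b10000000; // 1.000 * 2^4
const std::uint32_t kB = 0b00001001; // 1.001 * 2^0, aligned to 2^4
const std::uint32_t kC = 0b00001100; // 1.100 * 2^0, aligned to 2^4

// Exact sum, rounded once at the very end.
std::uint32_t RoundOnce() { return RoundTo4Bits(kA + kB + kC); }

// First sum rounded to 4 bits, then the second sum rounded again.
std::uint32_t RoundTwice() { return RoundTo4Bits(RoundTo4Bits(kA + kB) + kC); }
```

RoundOnce() yields 0b10010000 (1.001 * 2^4) while RoundTwice() yields 0b10100000 (1.010 * 2^4), matching the two hand-worked results.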
How to Pretend That You Have Infinite Precision
Okay, for FMAdd we want to calculate a multiply, and then somehow throw an add in there and have it act as if we didn't lose any precision on the multiply.
First we're going to handle the case of 32-bit floats (singles) because it's a wildly simpler case on CPUs that have 64-bit floats (doubles).
(also, sorry in advance, the term "double" for a "double-precision float" and the "double" in "double rounding" are two different instances of "double" but I've written so much of this post and like hell am I changing it now so hopefully it's not too confusing)
The immediately obvious thing to try to get an accurate single-precision FMA is "hey, what if we do the multiply and add as doubles and then round the result back down to a single":
float FMAdd(float a, float b, float c)
{
    // Do the math as 64-bit floats and round back down at the end.
    // Surely that's good enough, right?
    return float((double(a) * double(b)) + double(c));
}
While that gives a much better result than doing it purely in 32 bits, it can actually still suffer from double rounding. But where does the extra rounding come from, in this case?
The multiply itself isn't the source of the first rounding. Surprisingly (to me, at least), casting two singles to doubles and multiplying them together always gives an exact answer: each single-precision value has 24 bits of precision, and a double can store 53 bits of precision, which is more than enough to hold the product of two singles (at most 2 * 24 = 48 bits of precision). Since floats are stored as:
sign * 1.mantissa * 2^(exponent)
...it means we're multiplying two numbers of the form 1.xxxxxxxxxx and 1.yyyyyyyyy together, then adding the exponents together to get the new number. So unlike addition and subtraction (where, say, 1 + 1*10^60 requires a ton of extra precision), it doesn't matter if the two numbers have wildly different exponents, because the exponents and significands are handled separately. (The exponents do add, but a double's exponent range is far wider than a single's, so even the largest or smallest possible product of two singles fits comfortably.)
To illustrate this, let's pretend we have two 4-digit (base 10) numbers and we multiply them and store the result using 8 digits (double precision):
(1.234 * 10^1) * (1.457 * 10^100)
-> (1.2340000 * 1.4570000) * (10^1 * 10^100)
= 1.7979380 * 10^101 // no rounding here!
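The same exactness claim can be checked on real floats. This sketch relies on std::fma from <cmath>, which the C++ standard requires to compute x * y + z with a single rounding, so fma(a, b, -product) is exactly the amount the product lost (zero if nothing was lost). DoubleProductIsExact and FloatProductIsExact are hypothetical helper names.

```cpp
#include <cmath>

// True if double(a) * double(b) needed no rounding at all.
bool DoubleProductIsExact(float a, float b)
{
    const double da = a;
    const double db = b;
    const double product = da * db;
    // Exact a*b minus the computed product, rounded once: zero iff nothing was lost.
    return std::fma(da, db, -product) == 0.0;
}

// The same check, but with the product rounded to float first.
bool FloatProductIsExact(float a, float b)
{
    const float product = a * b;
    return std::fma(double(a), double(b), -double(product)) == 0.0;
}
```

16777215.0f is 2^24 - 1, the largest 24-bit significand, so squaring it needs the full 48 bits: DoubleProductIsExact(16777215.0f, 16777215.0f) should be true while FloatProductIsExact on the same inputs is false.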
Great, so the double-precision multiply is fine and introduces no rounding at all. So then how do we get double rounding?
As mentioned above, an add (or subtract) can introduce rounding:
(1.234 * 10^0) + (1.457 * 10^9)
-> (1.2340000 * 10^0) + (1.4570000 * 10^9)
= 1.457000001234 * 10^9 // Too many digits!
-> 1.4570000 * 10^9 // Rounded to nearest (8 digits) here
This rounding happens at double precision (so well below the threshold of our target 32-bit result), but there's still rounding, and then the value is rounded again when converted back down to single precision. That's the double rounding and the source of a potential error.
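To make that concrete, here is a sketch with inputs constructed for illustration (they are not from the post or the paper): a = (1 + 2^-23) * 2^-24, b = 1 - 2^-23, c = 1 + 2^-23. The exact result is 1 + 2^-23 + 2^-24 - 2^-70, just below the midpoint between the floats 1 + 2^-23 and 1 + 2^-22, so it should round to 1 + 2^-23. The double add drops the -2^-70, landing exactly on the midpoint, and the float conversion then breaks the tie toward the even neighbor, 1 + 2^-22. NaiveFMAdd and DoubleRoundingCase are hypothetical names.

```cpp
#include <cmath>

// The "just use doubles" approach from above.
float NaiveFMAdd(float a, float b, float c)
{
    return float((double(a) * double(b)) + double(c));
}

struct Inputs { float a, b, c; };

// Inputs whose double-precision sum rounds onto a float tie point.
Inputs DoubleRoundingCase()
{
    const float ulp = std::ldexp(1.0f, -23);  // one float ULP at 1.0
    return { std::ldexp(1.0f + ulp, -24),     // (1 + 2^-23) * 2^-24
             1.0f - ulp,                      // 1 - 2^-23
             1.0f + ulp };                    // 1 + 2^-23
}
```

On a conforming implementation, std::fma on these inputs should return 1 + 2^-23, while NaiveFMAdd returns 1 + 2^-22, one ULP too high.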
Okay, so, double rounding is bad? Kinda! But it turns out there is a way to introduce a new rounding mode to use for the first rounding that, in the right situations, does not introduce any additional error and ensures that your final result is correct.
A New Rounding Mode?
(This technique is based off of the paper Emulation of FMA and correctly-rounded sums: proved algorithms using rounding to odd by Sylvie Boldo and Guillaume Melquiond; if you want the full technical details of this whole process, that's where you'll find them. Believe it or not, I'm actually trying to go into less detail!)
The key to eliminating the extra precision loss is to use a non-standard rounding mode: rounding to odd. Standard floating-point rounding calculates results with some additional bits of precision (three bits, to be precise: the guard, round, and sticky bits), and then rounds based on the result (usually using "round to nearest with round to even on ties", although that detail doesn't end up mattering here, since this technique works with any standard rounding mode).
So, assume that we have some way of calculating a double-precision addition and also getting the error between the calculated result and the mathematically exact result. Given those two values, we can perform a Round To Odd step:
double RoundToOdd(double value, double errorTerm)
{
    if (errorTerm != 0.0                      // if the result is not exact
        && LowestBitOfMantissa(value) == 0)   // and the mantissa is even
    {
        // We need to round, so round either up or down to odd
        if (errorTerm > 0)
        {
            // Round up to an odd value
            value = AddOneBitToMantissa(value);
        }
        else // (errorTerm < 0)
        {
            // Round down to an odd value
            value = SubtractOneBitFromMantissa(value);
        }
    }
    return value;
}
Basically: if we have any error at all, and the mantissa is currently even, nudge the value by one unit in the last place (one ULP) toward the exact result: up if the error is positive, down if it's negative.
(In practice, I found that I also had to ensure the result was not Infinity before doing this operation, since I implemented this using some bitwise shenanigans that would end up "rounding" Infinity to NaN, so, you know, watch out for that).
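Here's one way the RoundToOdd pseudocode above could be made concrete. It's a sketch, not the post's actual bitwise implementation: it reads the lowest mantissa bit with std::memcpy and uses std::nextafter to step one ULP toward the exact result. Stepping toward +/-infinity, rather than literally adding 1 to the mantissa bits, also does the right thing for negative values, where rounding "up" shrinks the magnitude. MantissaIsOdd is a hypothetical helper name.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Hypothetical helper: is the lowest stored mantissa bit of `value` set?
bool MantissaIsOdd(double value)
{
    std::uint64_t bits;
    std::memcpy(&bits, &value, sizeof bits);  // well-defined way to read the bits
    return (bits & 1u) != 0;
}

double RoundToOdd(double value, double errorTerm)
{
    // Leave Inf/NaN alone (the Infinity caveat mentioned above).
    if (!std::isfinite(value))
        return value;

    if (errorTerm != 0.0 && !MantissaIsOdd(value))
    {
        // Step one ULP toward the exact result. Both neighbors of a value
        // with an even mantissa have odd mantissas, so we land on an odd one.
        const double inf = std::numeric_limits<double>::infinity();
        value = std::nextafter(value, errorTerm > 0.0 ? inf : -inf);
    }
    return value;
}
```

For example, RoundToOdd(1.0, 1e-30) moves to the next double above 1.0, while an already-odd value or a zero error leaves the input unchanged.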
Why Does Odd-Rounding the Intermediate Value Work?
Round to odd works as long as the intermediate value has more bits of precision than the final result; specifically, we need at least two extra bits. Standard float rounding makes use of something called a "sticky" bit: basically, the lowest bit of the extra precision is a 1 if any of the bits below it would have been 1.
And, hey, that is basically what "round to odd" does!
- If the mantissa is odd, the lowest bit is already 1, whether or not there's any error.
- If the error was positive and the mantissa was even, we set the lowest bit to 1 anyway, effectively stickying (yeah, that's a word now) all the error bits below it.
- If the error was negative and the mantissa was even, we subtract 1 from the mantissa, which makes the lowest bit 1; the exact value now sits slightly above our result, so some of the bits below that 1 really are nonzero, which makes it effectively sticky.
Effectively, round to odd is just "emulate having a sticky bit at the bottom of your intermediate result" - that way, you have a guaranteed tiebreaker for the final rounding step.
But note that I said it requires you to have at least two extra bits. In the case of our using-doubles-instead-of-singles intermediate addition, good news: we have way more than two extra bits - our intermediate value is a whole-ass double-precision float, so we have 29 extra bits vs. our single-precision final value and (mathematically speaking) 29 is greater than 2.
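The "at least two extra bits" requirement can be checked exhaustively on a small fixed-point model. This sketch assumes plain unsigned integers stand in for mantissas: rounding to odd drops `first` low bits (truncate, then OR in a sticky 1 if anything nonzero was dropped), and the final round-to-nearest-even drops the rest. RoundNearestEven, RoundToOddBits, and TwoStepMatchesDirect are hypothetical names for illustration.

```cpp
#include <cstdint>

// Drop the low `drop` bits of v, rounding to nearest with ties to even.
std::uint32_t RoundNearestEven(std::uint32_t v, unsigned drop)
{
    if (drop == 0)
        return v;
    const std::uint32_t kept = v >> drop;
    const std::uint32_t rest = v & ((1u << drop) - 1);
    const std::uint32_t half = 1u << (drop - 1);
    if (rest > half || (rest == half && (kept & 1u) != 0))
        return kept + 1;
    return kept;
}

// Drop the low `drop` bits of v, rounding to odd: truncate, then set the
// lowest kept bit if anything nonzero was discarded (the "sticky" bit).
std::uint32_t RoundToOddBits(std::uint32_t v, unsigned drop)
{
    const std::uint32_t kept = v >> drop;
    const std::uint32_t rest = v & ((1u << drop) - 1);
    return kept | (rest != 0 ? 1u : 0u);
}

// For every v below 2^totalBits: does rounding to odd (dropping `first` bits)
// and then to nearest-even (dropping the remaining bits) match rounding
// directly to nearest-even (dropping `finalDrop` bits)?
bool TwoStepMatchesDirect(unsigned totalBits, unsigned first, unsigned finalDrop)
{
    for (std::uint32_t v = 0; v < (1u << totalBits); ++v)
    {
        const std::uint32_t direct = RoundNearestEven(v, finalDrop);
        const std::uint32_t twoStep =
            RoundNearestEven(RoundToOddBits(v, first), finalDrop - first);
        if (direct != twoStep)
            return false;
    }
    return true;
}
```

With 16-bit values rounded to 8 bits, keeping 4 or 2 extra bits in the intermediate always agrees with direct rounding, while keeping only 1 extra bit does not: the sticky 1 then lands exactly on the halfway point and creates a false tie.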
So, for the true single-precision FMAdd instruction we need to do the following:
float FMAdd(float a, float b, float c)
{
    double product = double(a) * double(b); // No rounding here
    // Calculate our sum, but somehow get the error along with it
    (double sum, double err) = AddWithError(product, c);
    // Round our intermediate value to odd
    sum = RoundToOdd(sum, err);
    // Final rounding here, which now does the correct thing and gives us
    // a properly-rounded final result (as if we'd used infinite bits)
    return float(sum);
}
That's it! ...wait, what's that AddWithError function? We haven't even--
Calculating An Exact Addition Result
Right, we need to calculate that intermediate addition along with an accurate error term. It turns out it's possible to calculate a pair of numbers, sum and error, where mathematicallyExactSum = sum + error.
For this, we have to dive back to July of 1971 and check out the actually typewritten paper A Floating-Point Technique For Extending the Available Precision by T.J. Dekker. Give that a read if you want way more details on this whole thing.
Calculating the error term of adding two numbers (I'll use x and y) is relatively straightforward if |x| >= |y|:
sum = x + y;
err = y - (sum - x);
(this is equation 4.14 in the linked paper)
This is just a different ordering of (x + y) - sum that preserves accuracy: due to the nature of the values involved in these subtractions (sum's value is directly related to those of x and y, and y is smaller in magnitude than x), it turns out that each of those subtractions is exact (the paper has a proof of this, and it's a lot, so I'm not going to expand on it here), so we get the precise difference between the calculated sum and the real sum.
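As a runnable sketch of that two-line version (assuming IEEE-754 doubles with round-to-nearest and no extended-precision intermediates), returning both values in a small struct; SumAndError and FastTwoSum are hypothetical names, not from the paper.

```cpp
struct SumAndError
{
    double sum;
    double err;  // sum + err == x + y exactly
};

// Dekker's error term (equation 4.14); only valid when |x| >= |y|.
SumAndError FastTwoSum(double x, double y)
{
    const double sum = x + y;
    const double err = y - (sum - x);  // both subtractions are exact here
    return {sum, err};
}
```

For instance, 1e16 + 1.0 is a tie in double precision that rounds back to 1e16, and FastTwoSum(1e16, 1.0) recovers the lost 1.0 as the error term. Likewise, FastTwoSum(0.2, 0.1) reports that the familiar 0.30000000000000004 overshoots the exact sum of those two doubles by 2^-55.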
But this only works if you know that x's magnitude is larger than (or equal to) y's. If you don't know which of the two values has a larger magnitude, you can do a bit more work and end up with:
(double sum, double err) AddWithError(
    double x,
    double y)
{
    double sum = x + y;
    double intermediate = sum - x;
    double err1 = y - intermediate;
    double err2 = x - (sum - intermediate);
    return (sum, err1 + err2);
}
(This is effectively the expanded version of listing 4.16 from the linked paper)
- err1 here is the same as the err value in the first version (a precision-preserving rewrite of (x + y) - sum).
- err2 is, mathematically, x - (sum - (sum - x)), or 0; its goal is to calculate the error involved in calculating err1, since without the |x| >= |y| guarantee those subtractions might NOT be exact ... but these ones will be.
- Thus, summing these two error terms together gives us a final, precise error term.
(More details in the paper, hopefully this isn't too glossed over that it loses any meaning)
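A runnable C++ version of the AddWithError above might look like this sketch, which returns a small struct instead of the pseudo-tuple (SumAndError is a hypothetical name) and assumes round-to-nearest doubles with no extended-precision intermediates.

```cpp
struct SumAndError
{
    double sum;
    double err;  // sum + err == x + y exactly
};

// Error-free sum that works regardless of which input is larger.
SumAndError AddWithError(double x, double y)
{
    const double sum = x + y;
    const double intermediate = sum - x;           // the part of sum attributed to y
    const double err1 = y - intermediate;         // what was lost from y
    const double err2 = x - (sum - intermediate); // what was lost from x
    return {sum, err1 + err2};
}
```

With the inputs in the "wrong" order, AddWithError(1.0, 1e16) still reports an error of 1.0. Here it's err2 that catches it, because sum - x rounds back up to 1e16 and err1 comes out as 0.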
Finally, the End (For Single-Precision Floats)
So, yeah, that's how you implement the FMAdd instruction for single-precision floats on a machine that has double-precision support:
- Calculate the double-precision product of a and b
- Add this product to c to get a double-precision sum
- Calculate the error of the sum
- Use the error to odd-round the sum
- Round the double-precision sum back down to single precision
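Putting those steps together, here's a self-contained sketch of the whole single-precision FMAdd. It assumes IEEE-754 doubles with round-to-nearest and no extended-precision intermediates (e.g. SSE2 on x86-64 rather than x87). AddWithError is the branch-free error-free sum from the previous section, and RoundToOdd uses std::nextafter where the pseudocode used hypothetical mantissa helpers.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

struct SumAndError
{
    double sum;
    double err;  // sum + err == x + y exactly
};

// Error-free addition with no |x| >= |y| requirement.
SumAndError AddWithError(double x, double y)
{
    const double sum = x + y;
    const double intermediate = sum - x;
    const double err1 = y - intermediate;
    const double err2 = x - (sum - intermediate);
    return {sum, err1 + err2};
}

// If the sum was inexact and its lowest mantissa bit is 0, step one ULP
// toward the exact value so that bit becomes a sticky 1.
double RoundToOdd(double value, double errorTerm)
{
    // Skip Inf/NaN and results that were already exact.
    if (!std::isfinite(value) || errorTerm == 0.0)
        return value;

    std::uint64_t bits;
    std::memcpy(&bits, &value, sizeof bits);
    if ((bits & 1u) == 0)
    {
        const double inf = std::numeric_limits<double>::infinity();
        value = std::nextafter(value, errorTerm > 0.0 ? inf : -inf);
    }
    return value;
}

float FMAdd(float a, float b, float c)
{
    // Exact: two 24-bit significands need at most 48 bits, and a double has 53.
    const double product = double(a) * double(b);
    // One rounding at double precision; s.err holds what it lost.
    const SumAndError s = AddWithError(product, double(c));
    // Round to odd, then do the single rounding down to float.
    return float(RoundToOdd(s.sum, s.err));
}
```

Since std::fma is required to be correctly rounded, it makes a convenient reference for testing: on the double-rounding inputs a = (1 + 2^-23) * 2^-24, b = 1 - 2^-23, c = 1 + 2^-23, both should return 1 + 2^-23.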
But what if you have to calculate FMAdd for double-precision floats? You can't easily just cast up to, like, quad-precision floats and do the work there, so what now? Can you still do this?
The answer is yes, but it's a lot more work, and that's what the next post is about.