Fixed-point arithmetic (nfixed) improvements #2065

Open
fredrik-johansson opened this issue Sep 6, 2024 · 4 comments

Comments

@fredrik-johansson
Collaborator

fredrik-johansson commented Sep 6, 2024

The best way to do bulk few-limb arithmetic on real numbers seems to be to convert to fixed-point form for things like matrix multiplication, FFT and polynomial multiplication. This generally means using somewhat higher precision than floating-point for equivalent accuracy, but being able to skip all the shifts, normalisations and branches of floating-point seems to make up for this drawback. Additions are notably cheaper in fixed-point, and this is especially important for subclassical algorithms (Karatsuba, Strassen, ...) where the idea is to trade multiplications for additions.
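As a minimal sketch of what "convert to fixed-point form" can mean, the C code below splits a double with |x| < 1 into two 64-bit fraction limbs. It assumes a hypothetical two-limb sign+magnitude layout (`fixed2`, `fixed2_from_double` are illustrative names, not necessarily FLINT's nfixed layout). Multiplying by 2^64 is exact in binary floating point, so the only information lost is whatever lies below 2^-128.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical two-limb sign+magnitude fixed-point format:
   value = (-1)^sign * (d[1] * 2^-64 + d[0] * 2^-128), with |value| < 1. */
typedef struct {
    uint64_t sign;  /* 0 = nonnegative, 1 = negative (a whole limb) */
    uint64_t d[2];  /* fraction limbs, d[1] most significant */
} fixed2;

/* Convert a double with |x| < 1, truncating toward zero. Scaling by 2^64
   is exact, so only bits below 2^-128 are lost (relevant for tiny |x|). */
static fixed2 fixed2_from_double(double x)
{
    fixed2 r;
    double ax = (x < 0) ? -x : x;
    assert(ax < 1.0);
    double t = ax * 0x1p64;              /* in [0, 2^64) */
    uint64_t hi = (uint64_t) t;          /* integer part = top fraction limb */
    double rem = t - (double) hi;        /* remaining fraction, exact */
    r.sign = (x < 0);
    r.d[1] = hi;
    r.d[0] = (uint64_t) (rem * 0x1p64);  /* next 64 bits, truncated */
    return r;
}
```

With this sketch a value such as 2^-200 converts to zero, a small illustration of why poorly scaled inputs need either more precision or a smarter conversion strategy.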

Here are some things to improve/fix:

  • [Probably wontfix] One minor annoyance is how to represent signs. Two's complement would be better than sign+magnitude for adding and subtracting, as these operations then become completely branchless, but multiplications then become more complex and ultimately I don't think you gain anything. Currently a whole extra limb is used for the sign bit; this isn't ideal either, and one could steal a single low or high bit of the fraction limbs instead, but then you need some extra masking operations which might not be worth the trouble.

  • Right now additions and subtractions are inlined up to n = 8 limbs. For larger n inlining additions everywhere might be overkill, but we should investigate whether using some hardcoded addition functions for various lengths > 8 is significantly faster than mpn_add_n/mpn_sub_n.

  • For multiplying, it would suffice most of the time to compute less accurate high products with O(n) ulp error instead of ~2 ulp error, so it would be useful to have hardcoded and basecase assembly for these sloppier variants of mulhigh up to a few limbs.

  • As usual, dot products are the workhorse operation. For now, I have implemented _nfixed_dot_2, _nfixed_dot_3 and _nfixed_dot_4 in C using umul_ppmm / add_ssaaaa / add_sssaaaaaa where I break up the sums so that there are no long carry chains. Full assembly implementations interleaving the multiplications and additions optimally should be even faster though. For the smallest n one might want the whole dot product in assembly; for larger n some mpn_addmulhigh_n type functions might suffice.

  • One has to scale inputs to $(-\varepsilon,\varepsilon)$ in such a way that no intermediate result escapes $(-1,1)$. I have not rigorously proved that this scaling is correct in the algorithms implemented so far, which will of course be absolutely necessary for eventual use by arb etc. Ideally, the bound should not only be proven but also tight, so that no precision is wasted. In some cases, it would be more efficient to use an extra limb for the output to store carry-out into an integer part, but that is not implemented. Similarly, error bounds need to be proven (done, but not published).

  • The most important and also most difficult problem: how to convert non-uniform floating-point matrices/polynomials to fixed-point to achieve full entrywise accuracy. Currently this is done very pessimistically by increasing the precision, avoiding fixed-point completely when things are too poorly scaled (including when any zeros are present). It is clearly possible to do much more clever things.
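For the simplest case, a classical length-$k$ dot product, a sufficient (not necessarily tight) scaling condition can be sketched as follows. It assumes each product is computed by truncating its magnitude, as a high-product routine does, so no computed product $\widetilde{x_i y_i}$ exceeds the exact one in absolute value:

$$
|x_i| < \varepsilon,\; |y_i| < \varepsilon \;(1 \le i \le k)
\;\Longrightarrow\;
\Bigl|\sum_{i=1}^{j} \widetilde{x_i y_i}\Bigr| \le \sum_{i=1}^{j} |x_i|\,|y_i| < k\,\varepsilon^2 \quad (1 \le j \le k),
$$

so choosing $\varepsilon \le k^{-1/2}$ keeps every partial sum in $(-1,1)$. Subclassical schemes such as Karatsuba or Strassen form sums and differences of inputs before multiplying, which enlarges intermediate magnitudes, so they need their own bound.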

@albinahlback
Collaborator

> • Right now additions and subtractions are inlined up to n = 8 limbs. For larger n inlining additions everywhere might be overkill, but we should investigate whether using some hardcoded addition functions for various lengths > 8 is significantly faster than mpn_add_n/mpn_sub_n.

Some notes: GMP's tuning program can output cycles per limb for different lengths. To get a rough estimate of where hardcoding could actually be beneficial, run it and check which small lengths give cycle counts that do not yet resemble the asymptotic cost.

> • For multiplying, it would suffice most of the time to compute less accurate high products with O(n) ulp error instead of ~2 ulp error, so it would be useful to have hardcoded and basecase assembly for these sloppier variants of mulhigh up to a few limbs.

Are you referring to a less accurate version of mulhigh?

> • As usual, dot products are the workhorse operation. For now, I have implemented _nfixed_dot_2, _nfixed_dot_3 and _nfixed_dot_4 in C using umul_ppmm / add_ssaaaa / add_sssaaaaaa where I break up the sums so that there are no long carry chains. Full assembly implementations interleaving the multiplications and additions optimally should be even faster though. For the smallest n one might want the whole dot product in assembly; for larger n some mpn_addmulhigh_n type functions might suffice.

This is a very interesting idea. I believe it should be hardcoded, and my intuition is that we could gain a lot of performance here. However, I believe with such a routine, it is very important to think about (1) input data format and (2) algorithm/precision.
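One possible shape for such a routine, shown as a minimal sketch with a hypothetical one-limb sign+magnitude type `fixed1` (this is not the issue's _nfixed_dot_n code): exact products are accumulated into separate positive and negative sums chosen by a mask rather than a branch, and the sign is resolved once at the end. It assumes the inputs are scaled so that the sum of |x_i||y_i| stays below 1, and it uses the `unsigned __int128` GCC/Clang extension.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical one-limb sign+magnitude number: value = (-1)^sign * mag / 2^64. */
typedef struct {
    uint64_t sign;  /* 0 or 1 */
    uint64_t mag;
} fixed1;

/* Dot product sketch: exact 128-bit products go into a positive or a
   negative accumulator (selected with a mask), and one subtraction at the
   end fixes the sign. The result is the exact dot product truncated toward
   zero to one limb. Precondition (scaling): sum of |x_i| |y_i| < 1, so
   neither accumulator overflows 128 bits. */
static fixed1 fixed1_dot(const fixed1 *x, const fixed1 *y, size_t k)
{
    unsigned __int128 pos = 0, neg = 0;
    for (size_t i = 0; i < k; i++) {
        assert(x[i].sign <= 1 && y[i].sign <= 1);
        unsigned __int128 p = (unsigned __int128) x[i].mag * y[i].mag;
        /* all-ones mask if this product is negative, zero otherwise */
        unsigned __int128 m = -(unsigned __int128) (x[i].sign ^ y[i].sign);
        pos += p & ~m;
        neg += p & m;
    }
    fixed1 r;
    if (pos >= neg) {
        r.sign = 0;
        r.mag = (uint64_t) ((pos - neg) >> 64);
    } else {
        r.sign = 1;
        r.mag = (uint64_t) ((neg - pos) >> 64);
    }
    return r;
}
```

Splitting the sum by sign keeps each accumulator a plain unsigned add chain; a real multi-limb version would replace the 128-bit products with high products and would need the error analysis mentioned in the issue.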

Great ideas!

@fredrik-johansson
Collaborator Author

I just did a quick timing comparison (with TIMEIT_START/TIMEIT_STOP) for addition. The first column is n, the second is the time for mpn_add_n, the third is the time for the NN_ADD_n macros relative to mpn_add_n, and the fourth is the relative time for NN_ADD_n wrapped in a function.

n    mpn_add_n    NN_ADD_n    NN_ADD_n (wrapped)
2    2.710e-09      0.115       0.502
3    2.940e-09      0.155       0.537
4    2.720e-09      0.288       0.511
5    3.180e-09      0.478       0.431
6    3.410e-09      0.669       0.534
7    3.640e-09      0.786       0.500
8    3.200e-09      0.878       0.713

This suggests that fixed-length addition code should win well beyond n = 8.

Something surprising, though, is that the function call performs better than the inlined macro for n = 5 and beyond.

I think the reason is that the compiler (GCC) isn't able to interleave the add instructions with load instructions, so it has to load everything into registers, then do all adds at once, then write from registers. Wrapping this mess in a function then perhaps results in better overall register allocation?

Further, for big n the instruction sequence with all loads before the adds clearly isn't optimal for the CPU to execute. You'd want to interleave loads with adds so that they can be executed in parallel.

In conclusion:

  • Fixed-length mpn_add_n and mpn_sub_n functions should definitely help, perhaps up to n = 16 or so. Obviously, and fortunately, code size is much less of an issue for these functions than for mpn_mul_n.
  • The current macros should maybe be replaced with function calls for n >= 5. Or maybe with more cleverly implemented macros. Instead of taking all limbs as parameters, the macros could maybe take pointer arguments and then contain explicit load instructions.
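A minimal sketch of the pointer-argument idea from the conclusion above, for a hypothetical fixed-length function `nn_add_4` (the name and the portable-C carry idiom are assumptions, not FLINT code): each step loads one limb of each operand, adds with carry and stores, so loads and adds are interleaved in source order. Whether the compiler really emits an interleaved add/adc chain has to be checked in the generated assembly; a production version would more likely use inline assembly or the x86 `_addcarry_u64` intrinsic.

```c
#include <assert.h>
#include <stdint.h>

/* Add with carry in portable C: returns a + b + cin mod 2^64 and stores
   the carry-out (0 or 1) in *cout. */
static inline uint64_t adc64(uint64_t a, uint64_t b, uint64_t cin, uint64_t *cout)
{
    assert(cin <= 1);
    uint64_t s = a + b;
    uint64_t c1 = s < a;      /* carry from a + b */
    uint64_t s2 = s + cin;
    uint64_t c2 = s2 < s;     /* carry from adding cin */
    *cout = c1 | c2;
    return s2;
}

/* Hypothetical fixed-length analogue of mpn_add_n for n = 4: r = a + b,
   returning the carry-out. Limbs are loaded, added and stored one at a
   time instead of loading all operands up front. r may alias a or b. */
static uint64_t nn_add_4(uint64_t *r, const uint64_t *a, const uint64_t *b)
{
    uint64_t c = 0;
    r[0] = adc64(a[0], b[0], 0, &c);
    r[1] = adc64(a[1], b[1], c, &c);
    r[2] = adc64(a[2], b[2], c, &c);
    r[3] = adc64(a[3], b[3], c, &c);
    return c;
}
```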

@fredrik-johansson
Collaborator Author

> Are you referring to a less accurate version of mulhigh?

Yes, I think 95% of the time for fixed-point arithmetic a sloppy mulhigh makes sense because we've already padded the inputs to an extra limb of precision.
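The sloppy mulhigh idea can be sketched in portable C as below. This is a hypothetical reference version (`sloppy_mulhigh`, `add_at` and the limb layout are assumptions) meant to show which partial products are skipped, not to be fast. It forms only a[i]*b[j] with i + j >= n - 1; the skipped terms sum to less than n units of the lowest output limb, so the result underestimates the exact truncated high product by at most n ulp. It uses the `unsigned __int128` GCC/Clang extension.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define SLOPPY_MAXN 16  /* arbitrary size limit for this sketch */

/* Add v into t[pos], propagating the carry upward through t[0..len-1]. */
static void add_at(uint64_t *t, int len, int pos, uint64_t v)
{
    while (v != 0 && pos < len) {
        uint64_t s = t[pos] + v;
        v = (s < v);   /* carry out of this limb */
        t[pos] = s;
        pos++;
    }
}

/* r[0..n-1] approximates the top n limbs of the 2n-limb product a*b
   (limbs least significant first). Partial products with i + j < n - 1
   are skipped entirely, so the result never exceeds the exact value. */
static void sloppy_mulhigh(uint64_t *r, const uint64_t *a, const uint64_t *b, int n)
{
    /* t[m] holds limb position (n - 1) + m of the product, m = 0..n */
    uint64_t t[SLOPPY_MAXN + 1];
    assert(n >= 1 && n <= SLOPPY_MAXN);
    memset(t, 0, (size_t) (n + 1) * sizeof t[0]);
    for (int i = 0; i < n; i++)
        for (int j = n - 1 - i; j < n; j++) {
            unsigned __int128 p = (unsigned __int128) a[i] * b[j];
            int m = i + j - (n - 1);               /* slot of the low word */
            add_at(t, n + 1, m, (uint64_t) p);
            add_at(t, n + 1, m + 1, (uint64_t) (p >> 64));
        }
    /* drop limb position n - 1 (t[0]); keep the top n limbs */
    memcpy(r, t + 1, (size_t) n * sizeof t[0]);
}
```

For n = 2 and a = b with both limbs equal to 2^64 - 1, the exact top two limbs are {2^64 - 2, 2^64 - 1}, while this sketch returns {2^64 - 3, 2^64 - 1}, one ulp low and within the bound.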

@fredrik-johansson
Collaborator Author

> However, I believe with such a routine, it is very important to think about (1) input data format and (2) algorithm/precision.

Sure, the only question mark for me is whether the sign encoding used in the current code is the best choice.
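To make the sign question concrete, here is a minimal sketch of addition in a hypothetical two-limb sign+magnitude layout with a whole limb for the sign (the struct and helper names are assumptions; FLINT's actual nfixed layout may differ). Equal signs need only a magnitude addition, but unequal signs need a comparison to choose the subtraction order and the result sign. Those are the branches two's complement would remove, at the cost of more complex multiplication.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical layout: value = (-1)^sign * (d[1] * 2^-64 + d[0] * 2^-128). */
typedef struct {
    uint64_t sign;  /* 0 or 1, stored in a whole limb */
    uint64_t d[2];  /* fraction limbs, d[1] most significant */
} fixed2;

/* 1 if magnitude a >= magnitude b */
static int mag_ge(const uint64_t *a, const uint64_t *b)
{
    if (a[1] != b[1])
        return a[1] > b[1];
    return a[0] >= b[0];
}

/* r = a + b; the caller's scaling must guarantee no overflow past d[1] */
static void mag_add(uint64_t *r, const uint64_t *a, const uint64_t *b)
{
    uint64_t lo = a[0] + b[0];
    r[1] = a[1] + b[1] + (lo < a[0]);
    r[0] = lo;
}

/* r = a - b, requires a >= b */
static void mag_sub(uint64_t *r, const uint64_t *a, const uint64_t *b)
{
    uint64_t borrow = a[0] < b[0];
    r[0] = a[0] - b[0];
    r[1] = a[1] - b[1] - borrow;
}

/* Sign+magnitude addition; may return a "negative zero" (sign 1, mag 0). */
static fixed2 fixed2_add(fixed2 x, fixed2 y)
{
    fixed2 r;
    assert(x.sign <= 1 && y.sign <= 1);
    if (x.sign == y.sign) {          /* same sign: add magnitudes */
        r.sign = x.sign;
        mag_add(r.d, x.d, y.d);
    } else if (mag_ge(x.d, y.d)) {   /* different signs: compare first */
        r.sign = x.sign;
        mag_sub(r.d, x.d, y.d);
    } else {
        r.sign = y.sign;
        mag_sub(r.d, y.d, x.d);
    }
    return r;
}
```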

2 participants