[WIP] ggml-hexagon: Q4_0 mm opt #17907
base: master
## Conversation
```c
HVX_Vector_x4 r_dd =
    hvx_vec_load_and_mul_d_r2x2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size, y_d + i * y_dblk_size);
```
Optimized the scale multiplication step. The previous implementation only processed 32xf16 elements (half the vector width). This change enables 64xf16 multiplication to fully utilize the HVX vector capacity.
I'm getting garbled output for all models.
Also, ultimately we end up with the INT32 accumulator for each block (32 elements).
In order to multiply it with the FP16 scale we need to convert both (accumulator and scale) into FP32 (QF32). This means that we still need to do the same number of multiplies and use the same number of HVX registers either way.
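For context, here is a scalar sketch of the per-block computation being discussed (illustrative only, assuming the usual Q4_0/Q8 block layout; this is not the HVX implementation and the names are made up):

```c
#include <stdint.h>

// One 32-element block: integer dot product accumulated in int32,
// then scaled by the product of the two per-block scales in fp32.
static float block_dot_q4_q8(const int8_t w[32],   // Q4_0 weights, already sign-extended to -8..7
                             const int8_t a[32],   // quantized activations (int8)
                             float d_w, float d_a) // per-block scales (fp16 in the real format)
{
    int32_t acc = 0;
    for (int i = 0; i < 32; i++) {
        acc += (int32_t)w[i] * (int32_t)a[i];      // 4-bit x 8-bit products, int32 accumulator
    }
    return (float)acc * d_w * d_a;                 // accumulator and scales promoted to fp32
}
```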
> Also, ultimately we end up with the INT32 accumulator for each block (32 elements).
> In order to multiply it with the FP16 scale we need to convert both (accumulator and scale) into FP32 (QF32).

- Regarding the scales utilization: The original source uses 2 `Q6_Wqf32_vmpy_VhfVhf` instructions for 2 rows but ignores the upper half. This PR aims to fully utilize the results of both multiplications.
- As for the accumulator width: For `Q4_0`, an INT32 accumulator is likely excessive. Since `src0` (4-bit) * `src1` (8-bit) fits in 12 bits, accumulating 32 elements only requires 17 bits total. A 32-bit accumulator is far larger than what is strictly required (see the sketch after this list).
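For reference, the arithmetic behind the 17-bit figure (a standalone sketch, not code from this PR):

```c
#include <assert.h>

// Q4_0 weight magnitude <= 8, int8 activation magnitude <= 128.
static_assert(8 * 128 == 1 << 10, "per-element product needs 11 magnitude bits + sign = 12 bits");
static_assert(32 * (8 * 128) == 1 << 15, "sum of 32 products needs 16 magnitude bits + sign = 17 bits");
```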
> Also, ultimately we end up with the INT32 accumulator for each block (32 elements).
> In order to multiply it with the FP16 scale we need to convert both (accumulator and scale) into FP32 (QF32).
>
> - Regarding the scales utilization: The original source uses 2 `Q6_Wqf32_vmpy_VhfVhf` instructions for 2 rows but ignores the upper half. This PR aims to fully utilize the results of both multiplications.

Ah. Cool. I missed that. That part should help. Reading again.... :)

> - As for the accumulator width: For `Q4_0`, an INT32 accumulator is likely excessive. Since `src0` (4-bit) * `src1` (8-bit) fits in 12 bits, accumulating 32 elements only requires 17 bits total. A 32-bit accumulator is far larger than what is strictly required.

That'd be relevant only if we had native INT17 data type and instructions to use it efficiently.
> That'd be relevant only if we had native INT17 data type and instructions to use it efficiently.

That's true regarding INT17.
But I'm now thinking we could potentially reduce the precision of `src1` to INT7. If that works, we could keep the rest of the calculation within the f16 space. I'm not certain it will be viable yet, and it would be a significant refactor, so I'll need some time to investigate the details.
Quick update from my testing. I think most of the gain you're seeing from this PR comes from slightly wider processing of the blocks and not from the scales.
There is a slightly better (simpler) way to multiply scales from both rows.
I dug up my older code that does this:
```c
HVX_Vector     vyy_d  = Q6_Vh_vshuff_Vh(Q6_V_valign_VVR(vy_d, Q6_V_vror_VR(vy_d, 64), 64));
HVX_Vector     r01_d  = Q6_Vh_vshuff_Vh(Q6_V_valign_VVR(r1_d, Q6_V_vror_VR(r0_d, 64), 64));
HVX_VectorPair r01_dd = Q6_Wqf32_vmpy_VhfVhf(r01_d, vyy_d);
HVX_Vector     r0_dd  = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(r01_dd));
HVX_Vector     r1_dd  = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(r01_dd));
```
Both this and the vmux-based version you have in the PR do not, by themselves, improve things.
Looks like we either run out of registers or the extra instructions for mixing the scales eat away the gains.
Now, the overall gains from this PR are not very consistent, with some regressions:
## Before:

Gen3
```
common_perf_print: prompt eval time = 1923.41 ms / 205 tokens ( 9.38 ms per token, 106.58 tokens per second)
common_perf_print: eval time = 1429.96 ms / 63 runs (22.70 ms per token, 44.06 tokens per second)
```

Gen4
```
common_perf_print: prompt eval time = 1235.03 ms / 205 tokens ( 6.02 ms per token, 165.99 tokens per second)
common_perf_print: eval time = 1073.28 ms / 63 runs (17.04 ms per token, 58.70 tokens per second)
```

Gen5
```
common_perf_print: prompt eval time = 864.09 ms / 205 tokens ( 4.22 ms per token, 237.24 tokens per second)
common_perf_print: eval time = 1089.50 ms / 63 runs (17.29 ms per token, 57.82 tokens per second)
```

## After:

Gen3
```
common_perf_print: prompt eval time = 1773.68 ms / 205 tokens ( 8.65 ms per token, 115.58 tokens per second)
common_perf_print: eval time = 1373.27 ms / 63 runs (21.80 ms per token, 45.88 tokens per second)
```

Gen4
```
common_perf_print: prompt eval time = 1273.77 ms / 205 tokens ( 6.21 ms per token, 160.94 tokens per second)
common_perf_print: eval time = 1097.05 ms / 63 runs (17.41 ms per token, 57.43 tokens per second)
```

Gen5
```
common_perf_print: prompt eval time = 845.55 ms / 205 tokens ( 4.12 ms per token, 242.44 tokens per second)
common_perf_print: eval time = 1133.26 ms / 63 runs (17.99 ms per token, 55.59 tokens per second)
```
So, a mix of gains and regressions.
I'm thinking we'll be better off doing a new version of the multiply-reduce that avoids the reductions (i.e. shuffles and adds).
This can be done by interleaving groups of 4 ints across 8 blocks in the repack and dyn quant. With that we can also use the `Vw_vrmpyacc_VwVbVb` version, which is a fused multiply-accumulate.
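For reference, a scalar model of what one 32-bit lane of that `Vw_vrmpyacc_VwVbVb` style instruction computes (a sketch of the standard vrmpy-accumulate semantics, not project code):

```c
#include <stdint.h>

// Each word lane accumulates a 4-way dot product of signed bytes, which is why
// grouping 4 ints per block in the repack lines up with this instruction.
static int32_t vrmpyacc_lane(int32_t acc, const int8_t u[4], const int8_t v[4]) {
    for (int i = 0; i < 4; i++) {
        acc += (int32_t)u[i] * (int32_t)v[i];
    }
    return acc;
}
```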
I mentioned that I've been working on that version, so maybe give me a few more days to clean it up and then we can play with it. It should provide bigger and more consistent gains.
BTW, it might be good to do a clean rebase of this PR and squash/remove the commits we no longer need (i.e. the ROPE fixes, etc.).
@max-krasnyansky, I'd like to open a discussion here. Since the DMA engine can run in parallel with the HVX SIMD unit, I propose implementing a VTCM double-buffering strategy. This would allow us to overlap the DMA loads with the `vec_dot` computation.
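For concreteness, a rough sketch of the kind of VTCM double-buffering loop being proposed; `dma_start`, `dma_wait`, `vec_dot_chunk`, and the two buffers are hypothetical placeholders, not the actual ggml-hexagon API:

```c
// Hypothetical helpers (not the real backend API): start an async DMA transfer
// of one chunk into a VTCM buffer, wait for a previously issued transfer, and
// run the vec_dot kernel on a chunk that is already resident in VTCM.
extern void dma_start(void *vtcm_dst, int chunk);
extern void dma_wait(int chunk);
extern void vec_dot_chunk(const void *vtcm_src, int chunk);

static void matvec_double_buffered(void *vtcm_buf[2], int n_chunks) {
    if (n_chunks <= 0) return;
    dma_start(vtcm_buf[0], 0);                        // prefetch the first chunk
    for (int i = 0; i < n_chunks; i++) {
        if (i + 1 < n_chunks)
            dma_start(vtcm_buf[(i + 1) & 1], i + 1);  // issue the next transfer early
        dma_wait(i);                                  // usually a no-op once the pipeline is primed
        vec_dot_chunk(vtcm_buf[i & 1], i);            // compute overlaps the in-flight DMA
    }
}
```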
Actually the DMA is fully asynchronous and it already overlaps with vec_dot.
You get the idea. It's fully pipelined. Typically all the waits are no-ops except for the first one. The prompt, on the other hand, is compute-bound, and I'm working on redoing the matvec to optimize out the number of reductions that are needed (i.e. those rmpy_x8 functions can be improved, but that needs data layout/repack changes).
Thanks. I was referring to swapping the order so we issue the DMA request (step 4) before …
Yep, I understood the suggestion, and I would recommend re-reading my description again :)
BTW, I experimented with 32- and 64-row scratchpads and it doesn't really help with the current HVX implementation.
Yeah, you're right.
## Changes

- Added `hvx_vec_load_and_mul_d_rx2` and `hvx_vec_load_and_mul_d_r2x2` helper functions to streamline vector loading and multiplication.
- Reworked `vec_dot_q4x4x2_q8x4x2_rx2` and `vec_dot_q8x4x2_q8x4x2_rx2` to improve instruction pipelining and reduce overhead in the main loops.

## Performance
The following performance comparison shows significant improvements for `MUL_MAT(type_a=q4_0, type_b=f32)` across various batch sizes (n), with ~30% speedup observed for n >= 2.

Device: 8Gen3
Baseline: `4d3726278`
Current: `00d5fb31b`

Benchmarked `q4_0` at n=2, n=3, n=4, n=5, n=8.