gccでSSEを使ったベクトル演算を行う
n次元空間上にある二点間のユークリッド距離を計算する方法を考える.なお,それぞれの点は,n個の要素を持つfloatの配列で表されるとする.
単純に考えると,以下のような関数で計算可能となる.
float dist(const float *v1, const float *v2, int len) { float dist = 0.0; int i; for (i = 0; i < len; i++) { float d; d = v1[i] - v2[i]; d *= d; dist += d; } return dist; }
コンパイルオプションに-O3をつけたときの,アセンブリコードは以下の通りとなる.ちなみに,gccのバージョンは4.2.1となる.
.text .align 4,0x90 .globl _dist _dist: pushl %ebp movl %esp, %ebp pushl %esi pushl %ebx subl $4, %esp movl 8(%ebp), %esi movl 12(%ebp), %ecx movl 16(%ebp), %edx xorps %xmm1, %xmm1 testl %edx, %edx jle L4 xorps %xmm1, %xmm1 xorl %eax, %eax .align 4,0x90 L5: movss (%esi,%eax,4), %xmm0 # v1[i] を xmm0 レジスタへ転送 subss (%ecx,%eax,4), %xmm0 # d = v1[i] - v2[i]; mulss %xmm0, %xmm0 # d *= d; addss %xmm0, %xmm1 # dist += d; incl %eax # i++; cmpl %edx, %eax jne L5 L4: movss %xmm1, -12(%ebp) flds -12(%ebp) addl $4, %esp popl %ebx popl %esi leave ret .subsections_via_symbols
ラベル L5 がループの中身であるが,ループ中にmovss 命令や subss 命令が使われているのが分かる.これらの命令もSSE命令の一つであるが,1インストラクションで行う演算は一つのみである.
gccでは,attributeでvector型の型を定義でき,これを用いることでSSEによるベクトル演算が可能となる.上記のプログラムをSSEに対応したのが以下となる.
typedef float v4sf __attribute__ ((vector_size (16))); typedef union { v4sf v; float f[4]; } f4vec; float dist(float *v1, float *v2, int len) { float dist = 0.0; int i; for (i = 0; i < len; i += 4) { f4vec d; f4vec *f4v1, *f4v2; f4v1 = (f4vec*)&v1[i]; f4v2 = (f4vec*)&v2[i]; d.v = f4v1->v - f4v2->v; d.v *= d.v; dist += d.f[0] + d.f[1] + d.f[2] + d.f[3]; } return dist; }
以下の文でベクトル演算用のベクトル型の定義を行う.SSEのレジスタは16バイトであり,単精度浮動小数値ならば,同士に4つの演算を同時に行える.倍精度浮動小数ならば2つの演算となる.
typedef float v4sf __attribute__ ((vector_size (16)));
int型のベクトル演算を行ないたい場合は以下のようにする.intは基本的に32ビットなため,4つの演算が同時に行える.
typedef int v4sf __attribute__ ((vector_size (16)));
おなじく,-O3オプションを用いた時のアセンブリコードを見てみると以下のようなる.
.text .align 4,0x90 .globl _dist _dist: pushl %ebp movl %esp, %ebp pushl %esi pushl %ebx subl $32, %esp movl 8(%ebp), %esi movl 12(%ebp), %ecx movl 16(%ebp), %edx xorps %xmm1, %xmm1 testl %edx, %edx jle L4 xorps %xmm1, %xmm1 xorl %eax, %eax .align 4,0x90 L5: movaps (%esi,%eax,4), %xmm0 # v1[i], v1[i + 1], v1[i + 2], v1[i + 3] を # xmm0 レジスタへ転送 subps (%ecx,%eax,4), %xmm0 # v1[i] - v2[i], # v1[i + 1] - v2[i + 1], # v1[i + 2] - v2[i + 2], # v1[i + 3] - v2[i + 3] を計算 mulps %xmm0, %xmm0 # d.v *= d.v; を計算. # すなわち, # d.f[0] *= d.f[0], d.f[1] *= d.f[1] # d.f[2] *= d.f[2], d.f[3] *= d.f[3] movaps %xmm0, -24(%ebp) # xmm0 レジスタ 128ビット文をメモリへ転送 movss -24(%ebp), %xmm0 # 32 ビットずつ加算 addss -20(%ebp), %xmm0 addss -16(%ebp), %xmm0 addss -12(%ebp), %xmm0 addss %xmm0, %xmm1 # dist に結果を加算 addl $4, %eax # i += 4; cmpl %eax, %edx jg L5 L4: movss %xmm1, -28(%ebp) flds -28(%ebp) addl $32, %esp popl %ebx popl %esi leave ret .subsections_via_symbols
アセンブラコード中に,SSEのベクトル演算命令である movaps, subps, movps が出ていることが確認できる.これらの命令は1インストラクションで複数の演算を行うことが可能である.
このプログラムを計測してみる.計測用のコードは以下のようになる.
#include <sys/time.h> #include <stdio.h> #include <stdlib.h> #define DIM 128 #define LOOP 10000000 extern float dist(const float *v1, const float *v2, int len); int main(int argc, char *argv) { struct timeval t1, t2; float v1[DIM]; float v2[DIM]; int i, j; double diff1, diff2; srand48(time(NULL)); gettimeofday(&t1, NULL); for (i = 0; i < LOOP; i++) { for (j = 0; j < DIM; j++) { v1[j] = (float)drand48(); v2[j] = (float)drand48(); } dist(v1, v2, DIM); } gettimeofday(&t2, NULL); diff1 = (t2.tv_sec + t2.tv_usec / 1000000.0); diff1 -= (t1.tv_sec + t1.tv_usec / 1000000.0); gettimeofday(&t1, NULL); for (i = 0; i < LOOP; i++) { for (j = 0; j < DIM; j++) { v1[j] = (float)drand48(); v2[j] = (float)drand48(); } } gettimeofday(&t2, NULL); diff2 = (t2.tv_sec + t2.tv_usec / 1000000.0); diff2 -= (t1.tv_sec + t1.tv_usec / 1000000.0); printf("elapsed time = %f\n", diff1 - diff2); return 0; }
非SSEが1.851255秒となり,SSE対応バージョンが5.781228秒となった.
SSE対応バージョンの方が遅くなっている.これはおそらく以下のメモリアクセス部分がボトルネックになっていると考えられる.
movaps %xmm0, -24(%ebp) # xmm0 レジスタ 128ビット分をメモリへ転送 movss -24(%ebp), %xmm0 # 32 ビットずつ加算 addss -20(%ebp), %xmm0 addss -16(%ebp), %xmm0 addss -12(%ebp), %xmm0
SSEには8個の演算レジスタがあるため,これをフルに活用するようプログラムを変更してみる.変更後のプログラムは以下のようになる.
typedef float v4sf __attribute__ ((vector_size (16))); typedef union { v4sf v; float f[4]; } f4vec; float dist(float *v1, float *v2, int len) { float dist = 0.0; f4vec d, d0, d1, d2, d3, d4, d5, d6, d7; int i; for (i = 0; i < len; i += 32) { #define vectorize(IDX) \ do { \ f4vec *f4v1, *f4v2; \ int j = i + IDX * 4; \ \ f4v1 = (f4vec*)&v1[j]; \ f4v2 = (f4vec*)&v2[j]; \ \ d##IDX.v = f4v1->v - f4v2->v; \ d##IDX.v *= d##IDX.v; \ } while (0); vectorize(0); vectorize(1); vectorize(2); vectorize(3); vectorize(4); vectorize(5); vectorize(6); vectorize(7); d.v = d0.v + d1.v + d2.v + d3.v + d4.v + d5.v + d6.v + d7.v; dist += d.f[0] + d.f[1] + d.f[2] + d.f[3]; } return dist; }
このソースのアセンブリコードは次のようになる.
.text .align 4,0x90 .globl _dist _dist: pushl %ebp movl %esp, %ebp pushl %esi pushl %ebx subl $32, %esp movl 8(%ebp), %ecx movl 12(%ebp), %edx movl 16(%ebp), %esi testl %esi, %esi jle L9 xorps %xmm0, %xmm0 movss %xmm0, -28(%ebp) xorl %eax, %eax .align 4,0x90 L5: movaps (%ecx,%eax,4), %xmm6 subps (%edx,%eax,4), %xmm6 movaps 16(%ecx,%eax,4), %xmm7 subps 16(%edx,%eax,4), %xmm7 movaps 32(%ecx,%eax,4), %xmm5 subps 32(%edx,%eax,4), %xmm5 movaps 48(%ecx,%eax,4), %xmm4 subps 48(%edx,%eax,4), %xmm4 movaps 64(%ecx,%eax,4), %xmm3 subps 64(%edx,%eax,4), %xmm3 movaps 80(%ecx,%eax,4), %xmm2 subps 80(%edx,%eax,4), %xmm2 movaps 96(%ecx,%eax,4), %xmm1 subps 96(%edx,%eax,4), %xmm1 movaps 112(%ecx,%eax,4), %xmm0 subps 112(%edx,%eax,4), %xmm0 mulps %xmm0, %xmm0 mulps %xmm1, %xmm1 mulps %xmm2, %xmm2 mulps %xmm3, %xmm3 mulps %xmm4, %xmm4 mulps %xmm5, %xmm5 mulps %xmm6, %xmm6 mulps %xmm7, %xmm7 addps %xmm7, %xmm6 addps %xmm6, %xmm5 addps %xmm5, %xmm4 addps %xmm4, %xmm3 addps %xmm3, %xmm2 addps %xmm2, %xmm1 addps %xmm1, %xmm0 movaps %xmm0, -24(%ebp) movss -24(%ebp), %xmm0 addss -20(%ebp), %xmm0 addss -16(%ebp), %xmm0 addss -12(%ebp), %xmm0 addss -28(%ebp), %xmm0 movss %xmm0, -28(%ebp) addl $32, %eax cmpl %eax, %esi jg L5 flds -28(%ebp) addl $32, %esp popl %ebx popl %esi leave ret L9: xorps %xmm0, %xmm0 movss %xmm0, -28(%ebp) flds -28(%ebp) addl $32, %esp popl %ebx popl %esi leave ret .subsections_via_symbols
xmm7までのレジスタが利用されているのが分かる.同じようにこのプログラムを計測した時の結果は1.170874秒となり,SSE非対応バージョンよりも,37%程度速度が向上した.もっと速くなると思ったが,たいして速くならなくてがっかりだよ!