gccでSSEを使ったベクトル演算を行う

n次元空間上にある二点間のユークリッド距離を計算する方法を考える.なお,それぞれの点は,n個の要素を持つfloatの配列で表されるとする.

単純に考えると,以下のような関数で計算可能となる.

float
dist(const float *v1, const float *v2, int len)
{
        float dist = 0.0;
        int   i;

        for (i = 0; i < len; i++) {
                float d;

                d  = v1[i] - v2[i];
                d *= d;

                dist += d;
        }

        return dist;
}

コンパイルオプションに-O3をつけたときの,アセンブリコードは以下の通りとなる.ちなみに,gccのバージョンは4.2.1となる.

        .text
        .align 4,0x90
.globl _dist
_dist:
        pushl   %ebp
        movl    %esp, %ebp
        pushl   %esi
        pushl   %ebx
        subl    $4, %esp
        movl    8(%ebp), %esi
        movl    12(%ebp), %ecx
        movl    16(%ebp), %edx
        xorps   %xmm1, %xmm1
        testl   %edx, %edx
        jle     L4
        xorps   %xmm1, %xmm1
        xorl    %eax, %eax
        .align 4,0x90
L5:
        movss   (%esi,%eax,4), %xmm0  # v1[i] を xmm0 レジスタへ転送
        subss   (%ecx,%eax,4), %xmm0  # d = v1[i] - v2[i];
        mulss   %xmm0, %xmm0          # d *= d;
        addss   %xmm0, %xmm1          # dist += d;
        incl    %eax                  # i++;
        cmpl    %edx, %eax
        jne     L5
L4:
        movss   %xmm1, -12(%ebp)
        flds    -12(%ebp)
        addl    $4, %esp
        popl    %ebx
        popl    %esi
        leave
        ret
        .subsections_via_symbols

ラベル L5 がループの中身であるが,ループ中にmovss 命令や subss 命令が使われているのが分かる.これらの命令もSSE命令の一つであるが,1インストラクションで行う演算は一つのみである.


gccでは,attributeでvector型の型を定義でき,これを用いることでSSEによるベクトル演算が可能となる.上記のプログラムをSSEに対応したのが以下となる.

typedef float v4sf __attribute__ ((vector_size (16)));

typedef union {
        v4sf v;
        float f[4];
} f4vec;

float
dist(float *v1, float *v2, int len)
{
        float dist = 0.0;
        int   i;

        for (i = 0; i < len; i += 4) {
                f4vec  d;
                f4vec *f4v1, *f4v2;

                f4v1 = (f4vec*)&v1[i];
                f4v2 = (f4vec*)&v2[i];

                d.v  = f4v1->v - f4v2->v;
                d.v *= d.v;

                dist += d.f[0] + d.f[1] + d.f[2] + d.f[3];
        }

        return dist;
}


以下の文でベクトル演算用のベクトル型の定義を行う.SSEのレジスタは16バイトであり,単精度浮動小数値ならば,同士に4つの演算を同時に行える.倍精度浮動小数ならば2つの演算となる.

typedef float v4sf __attribute__ ((vector_size (16)));

int型のベクトル演算を行ないたい場合は以下のようにする.intは基本的に32ビットなため,4つの演算が同時に行える.

typedef int v4sf __attribute__ ((vector_size (16)));

おなじく,-O3オプションを用いた時のアセンブリコードを見てみると以下のようなる.

        .text
        .align 4,0x90
.globl _dist
_dist:
        pushl   %ebp
        movl    %esp, %ebp
        pushl   %esi
        pushl   %ebx
        subl    $32, %esp
        movl    8(%ebp), %esi
        movl    12(%ebp), %ecx
        movl    16(%ebp), %edx
        xorps   %xmm1, %xmm1
        testl   %edx, %edx
        jle     L4
        xorps   %xmm1, %xmm1
        xorl    %eax, %eax
        .align 4,0x90
L5:
        movaps  (%esi,%eax,4), %xmm0  # v1[i], v1[i + 1], v1[i + 2], v1[i + 3] を
                                      # xmm0 レジスタへ転送
        subps   (%ecx,%eax,4), %xmm0  # v1[i]     - v2[i],
                                      # v1[i + 1] - v2[i + 1],
                                      # v1[i + 2] - v2[i + 2],
                                      # v1[i + 3] - v2[i + 3] を計算
        mulps   %xmm0, %xmm0          # d.v *= d.v; を計算.
                                      # すなわち,
                                      # d.f[0] *= d.f[0], d.f[1] *= d.f[1]
                                      # d.f[2] *= d.f[2], d.f[3] *= d.f[3]
        movaps  %xmm0, -24(%ebp)      # xmm0 レジスタ 128ビット文をメモリへ転送
        movss   -24(%ebp), %xmm0      # 32 ビットずつ加算
        addss   -20(%ebp), %xmm0
        addss   -16(%ebp), %xmm0
        addss   -12(%ebp), %xmm0
        addss   %xmm0, %xmm1          # dist に結果を加算
        addl    $4, %eax              # i += 4;
        cmpl    %eax, %edx
        jg      L5
L4:
        movss   %xmm1, -28(%ebp)
        flds    -28(%ebp)
        addl    $32, %esp
        popl    %ebx
        popl    %esi
        leave
        ret
        .subsections_via_symbols

アセンブラコード中に,SSEのベクトル演算命令である movaps, subps, movps が出ていることが確認できる.これらの命令は1インストラクションで複数の演算を行うことが可能である.

このプログラムを計測してみる.計測用のコードは以下のようになる.

#include <sys/time.h>
#include <stdio.h>
#include <stdlib.h>

#define DIM 128
#define LOOP 10000000

extern float dist(const float *v1, const float *v2, int len);

int
main(int argc, char *argv)
{
        struct timeval t1, t2;
        float  v1[DIM];
        float  v2[DIM];
        int    i, j;
        double diff1, diff2;

        srand48(time(NULL));

        gettimeofday(&t1, NULL);
        
        for (i = 0; i < LOOP; i++) {
                for (j = 0; j < DIM; j++) {
                        v1[j] = (float)drand48();
                        v2[j] = (float)drand48();
                }

                dist(v1, v2, DIM);
        }

        gettimeofday(&t2, NULL);

        diff1  = (t2.tv_sec + t2.tv_usec / 1000000.0);
        diff1 -= (t1.tv_sec + t1.tv_usec / 1000000.0);


        gettimeofday(&t1, NULL);
        
        for (i = 0; i < LOOP; i++) {
                for (j = 0; j < DIM; j++) {
                        v1[j] = (float)drand48();
                        v2[j] = (float)drand48();
                }
        }

        gettimeofday(&t2, NULL);

        diff2  = (t2.tv_sec + t2.tv_usec / 1000000.0);
        diff2 -= (t1.tv_sec + t1.tv_usec / 1000000.0);


        printf("elapsed time = %f\n", diff1 - diff2);

        return 0;
}

非SSEが1.851255秒となり,SSE対応バージョンが5.781228秒となった.

SSE対応バージョンの方が遅くなっている.これはおそらく以下のメモリアクセス部分がボトルネックになっていると考えられる.

        movaps  %xmm0, -24(%ebp)      # xmm0 レジスタ 128ビット分をメモリへ転送
        movss   -24(%ebp), %xmm0      # 32 ビットずつ加算
        addss   -20(%ebp), %xmm0
        addss   -16(%ebp), %xmm0
        addss   -12(%ebp), %xmm0

SSEには8個の演算レジスタがあるため,これをフルに活用するようプログラムを変更してみる.変更後のプログラムは以下のようになる.

typedef float v4sf __attribute__ ((vector_size (16)));

typedef union {
        v4sf v;
        float f[4];
} f4vec;

float
dist(float *v1, float *v2, int len)
{
        float dist = 0.0;
        f4vec d, d0, d1, d2, d3, d4, d5, d6, d7;
        int   i;

        for (i = 0; i < len; i += 32) {
#define vectorize(IDX)                                          \
                do {                                            \
                        f4vec *f4v1, *f4v2;                     \
                        int j = i + IDX * 4;                    \
                                                                \
                        f4v1 = (f4vec*)&v1[j];                  \
                        f4v2 = (f4vec*)&v2[j];                  \
                                                                \
                        d##IDX.v  = f4v1->v - f4v2->v;          \
                        d##IDX.v *= d##IDX.v;                   \
                } while (0);

                vectorize(0);
                vectorize(1);
                vectorize(2);
                vectorize(3);
                vectorize(4);
                vectorize(5);
                vectorize(6);
                vectorize(7);

                d.v = d0.v + d1.v + d2.v + d3.v + d4.v + d5.v + d6.v + d7.v;

                dist += d.f[0] + d.f[1] + d.f[2] + d.f[3];
        }

        return dist;
}

このソースのアセンブリコードは次のようになる.

	.text
	.align 4,0x90
.globl _dist
_dist:
	pushl	%ebp
	movl	%esp, %ebp
	pushl	%esi
	pushl	%ebx
	subl	$32, %esp
	movl	8(%ebp), %ecx
	movl	12(%ebp), %edx
	movl	16(%ebp), %esi
	testl	%esi, %esi
	jle	L9
	xorps	%xmm0, %xmm0
	movss	%xmm0, -28(%ebp)
	xorl	%eax, %eax
	.align 4,0x90
L5:
	movaps	(%ecx,%eax,4), %xmm6
	subps	(%edx,%eax,4), %xmm6
	movaps	16(%ecx,%eax,4), %xmm7
	subps	16(%edx,%eax,4), %xmm7
	movaps	32(%ecx,%eax,4), %xmm5
	subps	32(%edx,%eax,4), %xmm5
	movaps	48(%ecx,%eax,4), %xmm4
	subps	48(%edx,%eax,4), %xmm4
	movaps	64(%ecx,%eax,4), %xmm3
	subps	64(%edx,%eax,4), %xmm3
	movaps	80(%ecx,%eax,4), %xmm2
	subps	80(%edx,%eax,4), %xmm2
	movaps	96(%ecx,%eax,4), %xmm1
	subps	96(%edx,%eax,4), %xmm1
	movaps	112(%ecx,%eax,4), %xmm0
	subps	112(%edx,%eax,4), %xmm0
	mulps	%xmm0, %xmm0
	mulps	%xmm1, %xmm1
	mulps	%xmm2, %xmm2
	mulps	%xmm3, %xmm3
	mulps	%xmm4, %xmm4
	mulps	%xmm5, %xmm5
	mulps	%xmm6, %xmm6
	mulps	%xmm7, %xmm7
	addps	%xmm7, %xmm6
	addps	%xmm6, %xmm5
	addps	%xmm5, %xmm4
	addps	%xmm4, %xmm3
	addps	%xmm3, %xmm2
	addps	%xmm2, %xmm1
	addps	%xmm1, %xmm0
	movaps	%xmm0, -24(%ebp)
	movss	-24(%ebp), %xmm0
	addss	-20(%ebp), %xmm0
	addss	-16(%ebp), %xmm0
	addss	-12(%ebp), %xmm0
	addss	-28(%ebp), %xmm0
	movss	%xmm0, -28(%ebp)
	addl	$32, %eax
	cmpl	%eax, %esi
	jg	L5
	flds	-28(%ebp)
	addl	$32, %esp
	popl	%ebx
	popl	%esi
	leave
	ret
L9:
	xorps	%xmm0, %xmm0
	movss	%xmm0, -28(%ebp)
	flds	-28(%ebp)
	addl	$32, %esp
	popl	%ebx
	popl	%esi
	leave
	ret
	.subsections_via_symbols

xmm7までのレジスタが利用されているのが分かる.同じようにこのプログラムを計測した時の結果は1.170874秒となり,SSE非対応バージョンよりも,37%程度速度が向上した.もっと速くなると思ったが,たいして速くならなくてがっかりだよ!