2009年5月23日星期六

串列矩陣乘法優化

埋頭苦幹於 Squirrel 與其捆綁庫已有一個多月,偶爾找一些好玩的來調劑一下,那就是有關 SIMD 的優化了。先呈上源碼:

// simdChainedMatrixMul.h
#include <xmmintrin.h>

// Load/store flavour selection.
// When USE_ALIGNED is defined to a non-zero value before this point, the
// aligned SSE load/store instructions are used and ASSERT_ALIGNED() checks
// that a pointer is a multiple of 16; otherwise the unaligned variants are
// used and ASSERT_ALIGNED() expands to nothing.
// NOTE(review): assert, intptr_t and size_t are expected to have been made
// visible by the including file; this header itself only pulls in
// <xmmintrin.h>.
#if USE_ALIGNED
# define _MM_LOAD _mm_load_ps
# define _MM_STORE _mm_store_ps
# define ASSERT_ALIGNED(p) assert(intptr_t(p) % 16 == 0)
#else
# define _MM_LOAD _mm_loadu_ps
# define _MM_STORE _mm_storeu_ps
# define ASSERT_ALIGNED(p)
#endif

/*! SIMD chained matrix multiplication

Computes result = mats[0] * mats[1] * ... * mats[count-1], treating every
operand as 16 contiguous floats forming four rows of four elements.
The running product is kept in four __m128 values for the whole chain and is
written to \a result only once, after the last multiplication.

Aliasing is allowed for the parameters: every input is read before anything
is stored to \a result.

\param result Pointer to 16 floats as the return value
\param mats An array of pointers, each pointing to a set of 16 floats
\param count Number of pointers in mats, must be at least 1
	(with count == 1 the single matrix is simply copied to \a result)

Example:
\code
Matrix44 m1, m2, m3;
Matrix44 result;

// Perform result = m1 * m2 * m3;
const float* mats[] = { (const float*)(&m1), (const float*)(&m2), (const float*)(&m3) };
simdChainedMatrixMul((float*)&result, mats, 3);
\endcode
*/
// 'inline' because the definition lives in a header that may be included
// from more than one translation unit.
inline void simdChainedMatrixMul(float* result, const float* mats[], size_t count)
{
	assert(count >= 1);
	ASSERT_ALIGNED(result);
	ASSERT_ALIGNED(mats[0]);

	// Running product, one __m128 per row; starts as the first matrix.
	__m128 ret[4];
	ret[0] = _MM_LOAD(mats[0] + 0*4);
	ret[1] = _MM_LOAD(mats[0] + 1*4);
	ret[2] = _MM_LOAD(mats[0] + 2*4);
	ret[3] = _MM_LOAD(mats[0] + 3*4);

	for(size_t j=1; j<count; ++j) {
		ASSERT_ALIGNED(mats[j]);

		// Prefetch the data of the next matrix (not the slot of the pointer
		// array), and only when there is a next matrix.
		// May not be useful at all, test it case by case.
		if(j + 1 < count)
			_mm_prefetch(reinterpret_cast<const char*>(mats[j+1]), _MM_HINT_NTA);

		// The four rows of the right hand side operand
		const __m128 rhs0 = _MM_LOAD(mats[j] + 0*4);
		const __m128 rhs1 = _MM_LOAD(mats[j] + 1*4);
		const __m128 rhs2 = _MM_LOAD(mats[j] + 2*4);
		const __m128 rhs3 = _MM_LOAD(mats[j] + 3*4);

		for(size_t i=0; i<4; ++i) {
			const __m128 row = ret[i];

			// Broadcast each element of the current row to all four lanes
			__m128 x0 = _mm_shuffle_ps(row, row, _MM_SHUFFLE(0,0,0,0));
			__m128 x1 = _mm_shuffle_ps(row, row, _MM_SHUFFLE(1,1,1,1));
			__m128 x2 = _mm_shuffle_ps(row, row, _MM_SHUFFLE(2,2,2,2));
			__m128 x3 = _mm_shuffle_ps(row, row, _MM_SHUFFLE(3,3,3,3));

			// New row = sum over k of row[k] * (row k of the right hand side)
			x0 = _mm_mul_ps(x0, rhs0);
			x1 = _mm_mul_ps(x1, rhs1);
			x2 = _mm_mul_ps(x2, rhs2);
			x3 = _mm_mul_ps(x3, rhs3);

			x2 = _mm_add_ps(x2, x0);
			x3 = _mm_add_ps(x3, x1);
			ret[i] = _mm_add_ps(x3, x2);
		}
	}

	_MM_STORE(result + 0*4, ret[0]);
	_MM_STORE(result + 1*4, ret[1]);
	_MM_STORE(result + 2*4, ret[2]);
	_MM_STORE(result + 3*4, ret[3]);
}
這一函數專門計算串列矩陣乘法,非常適合用於 transform traversal 或 skeleton animation 等等。除了使用 SIMD 指令一次處理四個浮點數之外,所有相乘後的臨時結果都存儲在 SSE 寄存器內,沒有多餘的記憶體移動指令被浪費掉。這兩項優化的結果使得 simdChainedMatrixMul 比一般矩陣乘法快三倍,以下是一個 Entity traversal 的測試程式:

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

// Benchmarking options
#define USE_ALIGNED 1
#define USE_SIMD 1

#include "simdChainedMatrixMul.h"

// 16-byte alignment specifier for Matrix44.
// It is spelled per compiler and written after the 'class' keyword, which is
// a position accepted by MSVC (__declspec) as well as GCC/Clang
// (__attribute__); the latter ignore an attribute placed in front of 'class'.
#if defined(_MSC_VER)
# define MATRIX44_ALIGN16 __declspec(align(16))
#else
# define MATRIX44_ALIGN16 __attribute__((aligned(16)))
#endif

//! A 4x4 matrix of floats, stored as four consecutive rows of four elements.
//! The same 16 floats are reachable by name (m00 .. m33) or as data[row][col].
class MATRIX44_ALIGN16 Matrix44
{
public:
	//! Fill every element with a pseudo random value in [-0.96, 0.96]
	//! taken from rand().
	void randomize()
	{
		for(int i=0; i<4; ++i) for(int j=0; j<4; ++j)
			data[i][j] = 0.96f * ((float(rand()) / RAND_MAX) * 2 - 1);
	}

	//! Access row \a i as an array of 4 floats; \a i is not range checked.
	float* operator[](size_t i) { return data[i]; }
	const float* operator[](size_t i) const { return data[i]; }

	/*!	Scalar matrix multiplication: *result = (*this) * (*rhs)
		No aliasing is allowed for the 2 parameters, and \a result must not
		be this matrix either, because elements of \a result are written
		while this matrix and \a rhs are still being read.
		\param result Receives the product
		\param rhs The right hand side operand
	 */
	void mul(Matrix44* __restrict result, const Matrix44* __restrict rhs) const
	{
		assert(result != rhs);
		assert(result != this);

		// Hand unrolled form of
		//   result[i][j] = sum over k of data[i][k] * rhs[k][j]
		result->m00 = m00 * rhs->m00 + m01 * rhs->m10 + m02 * rhs->m20 + m03 * rhs->m30;
		result->m01 = m00 * rhs->m01 + m01 * rhs->m11 + m02 * rhs->m21 + m03 * rhs->m31;
		result->m02 = m00 * rhs->m02 + m01 * rhs->m12 + m02 * rhs->m22 + m03 * rhs->m32;
		result->m03 = m00 * rhs->m03 + m01 * rhs->m13 + m02 * rhs->m23 + m03 * rhs->m33;

		result->m10 = m10 * rhs->m00 + m11 * rhs->m10 + m12 * rhs->m20 + m13 * rhs->m30;
		result->m11 = m10 * rhs->m01 + m11 * rhs->m11 + m12 * rhs->m21 + m13 * rhs->m31;
		result->m12 = m10 * rhs->m02 + m11 * rhs->m12 + m12 * rhs->m22 + m13 * rhs->m32;
		result->m13 = m10 * rhs->m03 + m11 * rhs->m13 + m12 * rhs->m23 + m13 * rhs->m33;

		result->m20 = m20 * rhs->m00 + m21 * rhs->m10 + m22 * rhs->m20 + m23 * rhs->m30;
		result->m21 = m20 * rhs->m01 + m21 * rhs->m11 + m22 * rhs->m21 + m23 * rhs->m31;
		result->m22 = m20 * rhs->m02 + m21 * rhs->m12 + m22 * rhs->m22 + m23 * rhs->m32;
		result->m23 = m20 * rhs->m03 + m21 * rhs->m13 + m22 * rhs->m23 + m23 * rhs->m33;

		result->m30 = m30 * rhs->m00 + m31 * rhs->m10 + m32 * rhs->m20 + m33 * rhs->m30;
		result->m31 = m30 * rhs->m01 + m31 * rhs->m11 + m32 * rhs->m21 + m33 * rhs->m31;
		result->m32 = m30 * rhs->m02 + m31 * rhs->m12 + m32 * rhs->m22 + m33 * rhs->m32;
		result->m33 = m30 * rhs->m03 + m31 * rhs->m13 + m32 * rhs->m23 + m33 * rhs->m33;
	}

	union {
		// Individual elements, mRC is the element at row R and column C
		struct { float
			m00, m01, m02, m03,
			m10, m11, m12, m13,
			m20, m21, m22, m23,
			m30, m31, m32, m33;
		};
		// As a 2 dimension array
		float data[4][4];
	};
};	// Matrix44

//! A simple Entity with single list structure.
//! Every Entity has a local matrix, at most one parent and at most one child.
//! An Entity owns its child: destroying an Entity deletes the chain below it.
class Entity
{
public:
#if USE_ALIGNED
	// Heap instances are allocated on a 16 byte boundary so that the matrix
	// member can be used with the aligned SSE loads.
	// NOTE(review): the result of _aligned_malloc is not checked for NULL.
	void* operator new(size_t size) { return _aligned_malloc(size, 16); }
	void operator delete(void* p) { _aligned_free(p); }
#endif

	//! Create a detached Entity with a randomized local matrix.
	Entity() : parent(NULL), child(NULL)
	{
		ASSERT_ALIGNED(&matrix);
		matrix.randomize();
	}

	//! Delete the child, which in turn deletes its own child and so on.
	~Entity() { delete child; }

	/*!	Link \a e directly below this Entity and take ownership of it.
		If this Entity already has a child (only reachable when the
		assertions are compiled out), \a e is inserted between this Entity
		and that child.
		\param e A childless Entity
		\return \a e, so that calls can be chained
	 */
	Entity* addChild(Entity* e)
	{
		assert(e->child == NULL);
		assert(!child);

		if(child) {
			// Hand the existing child over to the new node
			e->child = child;
			child->parent = e;
		}
		child = e;
		e->parent = this;
		return e;
	}

	//! World matrix using the scalar Matrix44::mul:
	//! matrix * parent->matrix * parent->parent->matrix * ...
	Matrix44 calculateWorldMatrix1()
	{
		Matrix44 result = matrix;
		Entity* e = this;
		while((e = e->parent) != NULL) {
			// mul() forbids aliasing, so multiply from a copy
			Matrix44 tmp(result);
			tmp.mul(&result, &e->matrix);
		}
		return result;
	}

	//! Same product as calculateWorldMatrix1(), computed with a single call
	//! to simdChainedMatrixMul. Terminates the program with exit(-1) when
	//! the chain up to the root is longer than 1024 entities.
	Matrix44 calculateWorldMatrix2()
	{
		const size_t maxChainLength = 1024;
		const float* matrixArray[maxChainLength] = {0};
		size_t count = 0;

		// Collect the matrices from this Entity up to the root.
		// The limit is checked before storing so that a chain of exactly
		// maxChainLength entities is still accepted.
		for(Entity* e = this; e != NULL; e = e->parent) {
			if(count >= maxChainLength)
				exit(-1);
			matrixArray[count++] = (float*)(&(e->matrix));
		}

		// Without a parent there is nothing to multiply; also keeps the call
		// below within the "at least 2 matrices" precondition that
		// simdChainedMatrixMul asserts.
		if(count < 2)
			return matrix;

		Matrix44 result;
		simdChainedMatrixMul((float*)(&result), matrixArray, count);

		return result;
	}

	Matrix44 matrix;
	Entity* parent, *child;
//	char padding[1024*1024];	// To test the effect of cache miss

private:
	// Not copyable: a copy would share 'child' and delete it twice.
	// Declared but intentionally not defined.
	Entity(const Entity&);
	Entity& operator=(const Entity&);
};

//! Benchmark driver: builds a chain of 1001 entities, then runs 10 rounds in
//! which the world matrix of the deepest Entity is computed 10000 times,
//! printing the elapsed time of each round in seconds.
//! USE_SIMD selects calculateWorldMatrix2 (SIMD) over calculateWorldMatrix1.
int main(int argc, char* argv[])
{
	srand(123);	// Fixed seed, so that every run works on the same matrices
	Entity* root = new Entity();

	Entity* lastChild = root;
	for(size_t i=1000; i--;)
		lastChild = lastChild->addChild(new Entity());

	for(size_t round=0; round<10; ++round) {
		// Print out a summation variable to prevent compiler optimization
		// that turns the whole benchmark into nothing
		float sum = 0;

		clock_t t1 = clock();

		if(!USE_SIMD) for(size_t n=10000; n--;)
			sum += lastChild->calculateWorldMatrix1().m00;
		else for(size_t n=10000; n--;)
			sum += lastChild->calculateWorldMatrix2().m00;

		clock_t t2 = clock();

		// 'sum' is a float, so it has to be printed with %f (it was %i)
		printf("%f, dummy:%f\n", float(t2 - t1) / CLOCKS_PER_SEC, sum);
	}

	// Deletes the whole chain through Entity's destructor
	delete root;
	return 0;
}

沒有留言:

發佈留言