#include "CudaVersionCheck.h"
#include "LpvStructs.h"
#include "LpvUtils.h"


#include "../Defines.h"
#include "cutil_math.h"
#include <math_constants.h>

// Per-cell merge-and-reset over a 3D volume of sz.x * sz.y * sz.z float4 cells.
//
// Launch layout: a 3D grid of 3D blocks, one thread per cell. Threads whose
// (x, y, z) coordinate falls outside sz do nothing, so the grid may overshoot.
// Cells are stored linearly with x fastest, then y, then z.
//
// For each cell, if dot(t, t) of the target value exceeds dot(s, s) of the
// source value, the target value is copied into source. The target cell is
// then always cleared to make_float4(0.0f).
//
// source, target: device arrays with at least sz.x * sz.y * sz.z elements.
// sz: volume extent along each axis.
__global__ void deviceMergeResetGeometry(float4* source, float4* target, int3 sz) {
	const int x = blockIdx.x * blockDim.x + threadIdx.x;
	const int y = blockIdx.y * blockDim.y + threadIdx.y;
	const int z = blockIdx.z * blockDim.z + threadIdx.z;

	// Threads past the edge of the volume have no cell to process.
	if(x >= sz.x || y >= sz.y || z >= sz.z) {
		return;
	}

	// Row-major linear offset: x fastest, then y, then z.
	const int cell = (z * sz.y + y) * sz.x + x;

	const float4 incoming = target[cell];
	const float4 current = source[cell];

	// Keep whichever value has the larger dot-with-itself.
	if(dot(incoming, incoming) > dot(current, current)) {
		source[cell] = incoming;
	}

	// The target cell is consumed: clear it unconditionally.
	target[cell] = make_float4(0.0f);
}

// Host wrapper that launches deviceMergeResetGeometry over a cubic volume of
// size * size * size float4 cells.
//
// source, target: device pointers to float4 arrays of at least size^3
//                 elements (passed as void* and cast here).
// size:           edge length of the cubic volume; non-positive sizes are a
//                 no-op.
//
// The kernel is launched asynchronously on the default stream with
// BLOCKSIZE^3 threads per block. No synchronization or error check is done
// here; callers that need the result or launch status must synchronize or
// query cudaGetLastError() themselves.
void cudaMergeResetGeometry(void* source, void* target, int size) {
	// An empty volume has no work. Ceil-division below would also yield a
	// zero-sized grid, which is an invalid launch configuration.
	if(size <= 0) {
		return;
	}

	// Evaluate the macro once so its expansion cannot interact with
	// operator precedence in the arithmetic below.
	const int blockEdge = BLOCKSIZE;

	dim3 dimBlock(blockEdge, blockEdge, blockEdge);

	// Ceil-divide so the grid just covers `size` cells per axis. The previous
	// `size/BLOCKSIZE+1` launched a whole extra layer of idle blocks on every
	// axis whenever size was a multiple of BLOCKSIZE.
	const int blocksPerAxis = (size + blockEdge - 1) / blockEdge;
	dim3 dimGrid(blocksPerAxis, blocksPerAxis, blocksPerAxis);

	deviceMergeResetGeometry<<<dimGrid, dimBlock>>>((float4*)source, (float4*)target, make_int3(size));
}
