#include "CudaVersionCheck.h"
#include "LpvStructs.h"
#include "LpvUtils.h"


#include "../Defines.h"
#include "cutil_math.h"
#include <math_constants.h>

#if 1
/**
 * One propagation step over the volume; each thread handles one cell of `target`.
 *
 * For every axis neighbour of the cell that lies inside `source`, the
 * neighbour's red/green/blue float4 coefficients are evaluated towards each of
 * the six faces of this cell, optionally attenuated by the geometry volume,
 * scaled by `factor`, and re-emitted through
 * constructSHClampedCosineLobeAroundDirection() for the face direction.
 * The summed result is stored in `target` and merged into `accumulated`.
 *
 * Launch layout: 3D grid of 3D blocks providing at least target.size threads
 * per axis; surplus threads leave through the bounds check.
 *
 * The result is accumulated in thread-local variables and stored once, instead
 * of read-modify-writing `target` on every neighbour/face iteration.
 * NOTE(review): this assumes `target` does not alias `source` or
 * `accumulated` - confirm against the callers.
 *
 * @param source          volume read from (neighbour cells, and the own cell when `first`)
 * @param target          volume that receives the propagated result of this step
 * @param accumulated     running sum; overwritten when `first`, added to otherwise
 * @param geometry        per-cell float4 occlusion coefficients, indexed with target.size
 * @param ix              per-cell byte, lowered to `level` for cells that received light
 * @param size            not used by this variant of the kernel
 * @param level           value written into `ix`
 * @param factor          scale applied to every flux contribution
 * @param useOcclusion    when false, the geometry volume is ignored
 * @param occlusionFactor strength of the occlusion term
 * @param first           true on the first step: own source value is added and
 *                        `accumulated` is initialised rather than incremented
 */
__global__ void devicePropagateVolumes(LPVComponents source, LPVComponents target, LPVComponents accumulated, 
			float4* geometry, 
			char* ix, int size, int level, 
			float factor, 
			bool useOcclusion, float occlusionFactor, 
			bool first) {
	int idx = blockIdx.x * blockDim.x + threadIdx.x;
	int idy = blockIdx.y * blockDim.y + threadIdx.y;
	int idz = blockIdx.z * blockDim.z + threadIdx.z;
	
	int3& sz = target.size;
	
	if (idx < sz.x && idy < sz.y && idz < sz.z) {
		
		// Flat index of this cell, shared by target/source/accumulated/ix.
		const int cell = idz*sz.y*sz.x + idy*sz.x + idx;
		
		// Thread-local accumulators; written back to global memory once at the end.
		float4 redOut = 	make_float4(0.0f);
		float4 greenOut = 	make_float4(0.0f);
		float4 blueOut = 	make_float4(0.0f);
		
		// The six axis directions; used both as neighbour steps and as face normals.
		int3 offsets[] = {
			make_int3(1,0,0), make_int3(-1,0,0),
			make_int3(0,1,0), make_int3(0,-1,0),
			make_int3(0,0,1), make_int3(0,0,-1)
		};
		
		// Offsets of the eight geometry-volume samples around a cell.
		int3 gvOffsets[] = {
			make_int3(0,0,0), make_int3(0,0,-1),
			make_int3(0,-1,0), make_int3(0,-1,-1), 
			make_int3(-1,0,0), make_int3(-1,0,-1),
			make_int3(-1,-1,0), make_int3(-1,-1,-1)
		};
		
		// For each neighbour direction, the four gvOffsets entries that are averaged.
		int gvSamplesPerNeighbor[][4] = {
			{4,5,6,7},
			{0,1,2,3},
			{2,3,6,7},
			{0,1,4,5},
			{1,3,5,7},
			{0,2,4,6}
		};
		
		// A zero factor makes the occlusion term below evaluate to 1 (no attenuation).
		if(!useOcclusion) {
			occlusionFactor = 0.0f;
		}
		
		int3 pos = make_int3(idx, idy, idz);
		
		for(int neighbor = 0; neighbor < 6; neighbor++) {
			int3 offset = offsets[neighbor];
			
			int3 targetPos = pos + offset;
			
			if(isInside(source.size, targetPos)) {
				int targetIndex = makeVolumeIndex(source.size, targetPos);
			
				float4 redIn = 		source.red[targetIndex];
				float4 greenIn = 	source.green[targetIndex];
				float4 blueIn = 	source.blue[targetIndex];
				
				// Average the geometry samples that lie inside the volume.
				float4 geometrySH = make_float4(0.0f);
				if(useOcclusion) {
					int count = 0;
					for(int i = 0; i < 4; i++) {
						int3 gvPos = targetPos + gvOffsets[gvSamplesPerNeighbor[neighbor][i]];
						if(isInside(sz, gvPos)) {
							geometrySH += geometry[makeVolumeIndex(sz, gvPos)];
							count++;
						}
					}
					// Guard the division so an empty sample set cannot produce NaNs.
					if(count > 0) {
						geometrySH /= count;
					}
				}
				
				for(int face = 0; face < 6; face++) {
					// Vector from the neighbour's centre to the centre of this cell's face.
					float3 dir = 0.5f * make_float3(offsets[face]) - make_float3(offset);
					float len = length(dir);
					dir /= len;
		
					// len == 0.5: the face touching the neighbour, receives nothing.
					// len == 1.5: the face opposite the neighbour.
					// otherwise:  one of the four side faces.
					float solidAngle;
					if(len <= 0.5f) {
						solidAngle = 0.0f; 
					} else {
						solidAngle = (len >= 1.5f) ? 22.95668f/(4.0f*180.0f) : 24.26083f/(4.0f*180.0f);
					}
		
					float4 dirSH = constructSH(dir);
			
					float occlusion = 1.0f - max(0.0f, min(1.0f, occlusionFactor * dot(geometrySH , dirSH)));
			
					// Same weight for all three colour channels.
					float weight = occlusion * solidAngle;
					float3 flux = make_float3(
								weight * max(0.0f, dot(redIn, dirSH)),
								weight * max(0.0f, dot(greenIn, dirSH)),
								weight * max(0.0f, dot(blueIn, dirSH))
							);
		
					flux *= factor;
		
					float4 coeffs = constructSHClampedCosineLobeAroundDirection(make_float3(offsets[face]));
	
					redOut += flux.x * coeffs;
					greenOut += flux.y * coeffs;
					blueOut += flux.z * coeffs;
				}
			}
		}
		
		if(first) {
			// NOTE(review): on the first step the cell's own source value ends up in
			// `target` as well as in `accumulated`; kept as in the original - confirm intent.
			redOut += source.red[cell];
			greenOut += source.green[cell];
			blueOut += source.blue[cell];
			
			accumulated.red[cell] = 	redOut;
			accumulated.green[cell] = 	greenOut;
			accumulated.blue[cell] = 	blueOut;
		} else {
			accumulated.red[cell] += 	redOut;
			accumulated.green[cell] += 	greenOut;
			accumulated.blue[cell] += 	blueOut;
		}
		
		// Single store of the step result.
		target.red[cell] = 		redOut;
		target.green[cell] = 	greenOut;
		target.blue[cell] = 	blueOut;
		
		// Record the lowest level at which this cell held a non-zero value.
		if(!isZero(redOut) || !isZero(greenOut) || !isZero(blueOut)) {
			char& oldValue = ix[cell];
			oldValue = min(oldValue, level);
		}
	}
}
#else
// Builds the vector (-sh.w, -sh.y, sh.z) from the given coefficients and
// returns it normalized; an all-zero vector is returned as is, so that
// normalize() is never applied to a zero-length input.
__device__ float3 getPrimaryDirection(float4 sh) {
	float3 raw = make_float3(-sh.w, -sh.y, sh.z);
	const bool isNull = (raw.x == 0.0f) && (raw.y == 0.0f) && (raw.z == 0.0f);
	return isNull ? raw : normalize(raw);
}

// Alternative propagation kernel; each thread handles one cell of `target`.
//
// For every axis neighbour inside `source`, the neighbour's coefficients are
// evaluated against a lobe pointing from the neighbour towards this cell, and
// the resulting scalar per channel is re-emitted as a lobe around the
// direction returned by getPrimaryDirection() for that channel. The sum is
// scaled by `factor`, stored in `target` and added to `accumulated`; when
// `first` is set, the cell's own `source` value is added to `accumulated` too.
// Cells that end up non-zero lower the `ix` entries of the block of
// full-resolution cells they cover (edge length 1 << level, row length `size`).
//
// `geometry`, `useOcclusion` and `occlusionFactor` are not used by this variant.
__global__ void devicePropagateVolumes(LPVComponents source, LPVComponents target, LPVComponents accumulated, 
			float4* geometry, 
			char* ix, int size, int level, 
			float factor, 
			bool useOcclusion, float occlusionFactor, 
			bool first) {
	
	int idx = blockIdx.x * blockDim.x + threadIdx.x;
	int idy = blockIdx.y * blockDim.y + threadIdx.y;
	int idz = blockIdx.z * blockDim.z + threadIdx.z;
	
	int3& sz = target.size;
	
	// Threads outside the volume have nothing to do.
	if (idx >= sz.x || idy >= sz.y || idz >= sz.z) {
		return;
	}
	
	// Flat index of this cell in target/source/accumulated.
	const int cell = idz*sz.y*sz.x + idy*sz.x + idx;
	
	float4& redOut = 	target.red[cell];
	float4& greenOut = 	target.green[cell];
	float4& blueOut = 	target.blue[cell];
	
	redOut = 	make_float4(0.0f);
	greenOut = 	make_float4(0.0f);
	blueOut = 	make_float4(0.0f);
	
	// The six axis neighbours.
	int3 steps[] = {
		make_int3(1,0,0), make_int3(-1,0,0),
		make_int3(0,1,0), make_int3(0,-1,0),
		make_int3(0,0,1), make_int3(0,0,-1)
	};
	
	int3 cellPos = make_int3(idx, idy, idz);
	
	for(int n = 0; n < 6; n++) {
		int3 step = steps[n];
		int3 neighborPos = cellPos + step;
		
		if(!isInside(source.size, neighborPos)) {
			continue;
		}
		
		int neighborIndex = makeVolumeIndex(source.size, neighborPos);
		
		float4 redIn = 		source.red[neighborIndex];
		float4 greenIn = 	source.green[neighborIndex];
		float4 blueIn = 	source.blue[neighborIndex];
		
		// Lobe pointing from the neighbour back towards this cell.
		float4 towardsCell = constructSHClampedCosineLobeAroundDirection(make_float3(-step.x, -step.y, -step.z));
		
		float incidentRed = 	max(0.0f, dot(redIn, towardsCell));
		float incidentGreen = 	max(0.0f, dot(greenIn, towardsCell));
		float incidentBlue = 	max(0.0f, dot(blueIn, towardsCell));
		
		redOut += 	incidentRed * constructSHClampedCosineLobeAroundDirection(getPrimaryDirection(redIn));
		greenOut += incidentGreen * constructSHClampedCosineLobeAroundDirection(getPrimaryDirection(greenIn));
		blueOut += 	incidentBlue * constructSHClampedCosineLobeAroundDirection(getPrimaryDirection(blueIn));
	}
	
	redOut *= factor;
	greenOut *= factor;
	blueOut *= factor;
	
	// On the first step the cell's own source value joins the running sum.
	if(first) {
		accumulated.red[cell] += 	source.red[cell];
		accumulated.green[cell] += 	source.green[cell];
		accumulated.blue[cell] += 	source.blue[cell];
	}
	
	accumulated.red[cell] += 	redOut;
	accumulated.green[cell] += 	greenOut;
	accumulated.blue[cell] += 	blueOut;
	
	if(isZero(redOut) && isZero(greenOut) && isZero(blueOut)) {
		return;
	}
	
	// Range of full-resolution cells covered by this cell at the given level.
	const int xBegin = idx << level, xEnd = (idx + 1) << level;
	const int yBegin = idy << level, yEnd = (idy + 1) << level;
	const int zBegin = idz << level, zEnd = (idz + 1) << level;
	
	for(int z = zBegin; z < zEnd; z++) {
		for(int y = yBegin; y < yEnd; y++) {
			for(int x = xBegin; x < xEnd; x++) {
				char& storedValue = ix[z*size*size + y*size + x];
				storedValue = min(storedValue, level);
			}
		}
	}
}
#endif


/**
 * Host-side launcher for devicePropagateVolumes.
 *
 * Wraps the raw device pointers into LPVComponents descriptors for a cubic
 * volume of edge length (size >> level) and launches one thread per cell.
 *
 * The grid is sized from the edge length of the current level using a
 * ceiling division, so no blocks are launched that would fall entirely
 * outside the volume. The kernel's own bounds check handles the partially
 * filled blocks at the border.
 *
 * The launch is asynchronous and no error is queried here; callers that need
 * to detect launch failures have to check the CUDA error state themselves.
 *
 * @param sources         device pointers to the red/green/blue source volumes (float4 each)
 * @param targets         device pointers to the red/green/blue target volumes (float4 each)
 * @param accumulated     device pointers to the red/green/blue accumulation volumes (float4 each)
 * @param geometry        device pointer to the geometry volume (float4)
 * @param ix              device pointer to the per-cell byte volume
 * @param size            edge length at level 0
 * @param level           level to process; the edge length used is size >> level
 * @param factor          forwarded to the kernel
 * @param useOcclusion    forwarded to the kernel
 * @param occlusionFactor forwarded to the kernel
 * @param first           forwarded to the kernel
 */
void cudaPropagateVolumes(void* sources[3], void* targets[3], void* accumulated[3], void* geometry, void* ix, int size, int level, float factor, bool useOcclusion, float occlusionFactor, bool first) {
	// Edge length of the volume at the requested level.
	const int levelSize = size >> level;
	if (levelSize <= 0) {
		// Empty volume: no cell would pass the kernel's bounds check, and a
		// zero-sized grid would be an invalid launch configuration.
		return;
	}
	int3 sz = make_int3(levelSize);

	LPVComponents src;
	src.size = sz;
	src.red = (float4*)sources[0];
	src.green = (float4*)sources[1];
	src.blue = (float4*)sources[2];
	
	LPVComponents trg;
	trg.size = sz;
	trg.red = (float4*)targets[0];
	trg.green = (float4*)targets[1];
	trg.blue = (float4*)targets[2];
	
	LPVComponents sum;
	sum.size = sz;
	sum.red = (float4*)accumulated[0];
	sum.green = (float4*)accumulated[1];
	sum.blue = (float4*)accumulated[2];
	
	// Ceiling division: just enough blocks per axis to cover levelSize cells.
	const int blocksPerAxis = (levelSize + (BLOCKSIZE) - 1) / (BLOCKSIZE);
	
	dim3 dimBlock(BLOCKSIZE, BLOCKSIZE, BLOCKSIZE);
	dim3 dimGrid(blocksPerAxis, blocksPerAxis, blocksPerAxis);
	
	devicePropagateVolumes<<<dimGrid, dimBlock>>>(src, trg, sum, 
			(float4*)geometry, 
			(char*)ix, size, level, 
			factor, 
			useOcclusion, occlusionFactor, 
			first);
}
