#include "LpvUtils.h"

#include "CudaVersionCheck.h"
#include "../Defines.h"
#include "cutil_math.h"
#include <math_constants.h>

// Returns the sign of f as a float: -1.0f when f is negative, +1.0f when f is positive,
// and 0.0f otherwise (either signed zero, or NaN, since neither comparison holds for those).
__device__ float sign(float f) {
	return (f < 0.0f) ? -1.0f
	     : (f > 0.0f) ?  1.0f
	     :               0.0f;
}

// Multiplies the 4x4 matrix stored in transform.v by the column vector v.
// The 16 floats are read as four consecutive rows of four (m[0..3] is the first row),
// and each output component is the dot product of one row with v.
__device__ float4 transform4x4(const float4& v, Transform transform) {
	const float* m = transform.v;

	const float4 row0 = make_float4(m[0],  m[1],  m[2],  m[3]);
	const float4 row1 = make_float4(m[4],  m[5],  m[6],  m[7]);
	const float4 row2 = make_float4(m[8],  m[9],  m[10], m[11]);
	const float4 row3 = make_float4(m[12], m[13], m[14], m[15]);

	float4 result;
	result.x = dot(v, row0);
	result.y = dot(v, row1);
	result.z = dot(v, row2);
	result.w = dot(v, row3);
	return result;
}

// Evaluates the first two spherical-harmonic bands for a direction.
// x holds the constant band-0 term; y, z and w hold the band-1 terms built from
// direction.y, direction.z and direction.x respectively (y and w carry a negative sign).
__device__ float4 constructSH(float3 direction) {
	const float band0 = 0.282094792f;   // numerically 1 / (2 * sqrt(pi))
	const float band1 = 0.4886025119f;  // numerically sqrt(3) / (2 * sqrt(pi))

	float4 sh;
	sh.x = band0;
	sh.y = -band1 * direction.y;
	sh.z =  band1 * direction.z;
	sh.w = -band1 * direction.x;
	return sh;
}

#if 0
// Disabled alternative: same component layout, but with coefficients 0.5 / 0.75 instead of
// the pi- and 2*pi/3-scaled basis constants used by the active version below.
__device__ float4 constructSHClampedCosineLobeAroundDirection(float3 direction) {
	return make_float4(0.5f, -0.75f * direction.y, 0.75f * direction.z, -0.75f * direction.x);
}
#else
// Builds the two-band SH coefficients of a clamped cosine lobe oriented along `direction`
// (as the name states). The band-0 basis constant is scaled by pi, and the band-1 constant
// is scaled by 2*pi/3. The component layout is the same as constructSH: x is constant, and
// y, z, w come from direction.y, .z and .x.
__device__ float4 constructSHClampedCosineLobeAroundDirection(float3 direction) {
	const float lobe0 = CUDART_PI_F * 0.282094792f;
	const float lobe1 = ((2.0f*CUDART_PI_F)/3.0f) * 0.4886025119f;

	float4 sh;
	sh.x = lobe0;
	sh.y = -lobe1 * direction.y;
	sh.z =  lobe1 * direction.z;
	sh.w = -lobe1 * direction.x;
	return sh;
}
#endif

// Reconstructs a position from one texel of the RSM depth volume.
//
// rsm_depth : 3D texture; its z layer index is used as a cube-face selector (0..5 = +X,-X,+Y,-Y,+Z,-Z).
// size      : texel dimensions of rsm_depth, used to turn texCoord into an integer texel index.
// texCoord  : coordinate in [0,1]^3. x/y give the in-face position, z selects the layer.
// Returns (dir * depth, 1). dir is the normalized face-dependent direction built from (s, t),
// and depth = EXTRACT_DEPTH(sample) * RSM_CAMERA_FAR.
// NOTE(review): EXTRACT_DEPTH lives in Defines.h; presumably it decodes a [0,1] depth from the
// uchar4 sample. Confirm.
__device__ float4 rsmDepthExtractLightSpacePosition(texture<uchar4, cudaTextureType3D, cudaReadModeNormalizedFloat> rsm_depth, int3 size, float3 texCoord) {
	// Integer texel index, clamped into the volume.
	int3 i = make_int3(texCoord * make_float3(size));
	i = clamp(i, make_int3(0,0,0), size-1);

	// Sample at the texel centre. Under point filtering this hits the same texel as the integer
	// coordinate. Under linear filtering it avoids averaging with the neighbouring texel.
	float4 encodedDepth = tex3D(rsm_depth, i.x + 0.5f, i.y + 0.5f, i.z + 0.5f);
	float depth = EXTRACT_DEPTH(encodedDepth) * RSM_CAMERA_FAR;

	// In-face coordinates remapped from [0,1] to [-1,1].
	float s = 2.0f * texCoord.x - 1.0f;
	float t = 2.0f * texCoord.y - 1.0f;

	float3 ret;
	switch(i.z) {
	case 0: // POSITIVE_X
		ret = make_float3(1.0f, t, s);
		break;
	case 1: // NEGATIVE_X
		ret = make_float3(-1.0f, t, -s);
		break;
	case 2: // POSITIVE_Y
		ret = make_float3(-s, 1.0f, -t);
		break;
	case 3: // NEGATIVE_Y
		ret = make_float3(-s, -1.0f, t);
		break;
	case 5: // NEGATIVE_Z
		ret = make_float3(s, t, -1.0f);
		break;
	case 4: // POSITIVE_Z
	default:
		// Layers beyond 5 (possible when size.z > 6) previously left `ret` uninitialized.
		// Fall back to the +Z mapping so the result is always defined.
		ret = make_float3(-s, t, 1.0f);
		break;
	}

	return make_float4(normalize(ret) * depth, 1.0f);
}

// Accumulates val into addr.
// Active path: four separate atomicAdd calls, one per component. Each component is updated
// atomically, but the float4 as a whole is not, so other threads may observe a partially
// updated value. The previous values returned by atomicAdd are discarded.
// Disabled path: a plain, non-atomic read-modify-write (presumably a debugging toggle).
__device__ void componentwiseAtomicFloat4Add(float4& addr, float4 val) {
#if 1
	(void)atomicAdd(&addr.x, val.x);
	(void)atomicAdd(&addr.y, val.y);
	(void)atomicAdd(&addr.z, val.z);
	(void)atomicAdd(&addr.w, val.w);
#else
	addr += val;
#endif
}

// True when pos lies in the half-open box [0, size) along every axis.
__device__ bool isInside(const int3& size, const int3& pos) {
	const bool outside =
		pos.x < 0 || pos.x >= size.x ||
		pos.y < 0 || pos.y >= size.y ||
		pos.z < 0 || pos.z >= size.z;
	return !outside;
}

// Flattens a 3D cell position into a linear index with x varying fastest, then y, then z.
// No bounds checking is done; pair with isInside() when pos may be out of range.
__device__ int makeVolumeIndex(const int3& size, const int3& pos) {
	const int slice = pos.z * size.y + pos.y;  // row index across all z-slices
	return slice * size.x + pos.x;
}

// Maps a cube-face index to its outward axis direction:
// 0:+X 1:-X 2:+Y 3:-Y 4:+Z 5:-Z. Any other value yields +Z.
// Even indices point along the positive axis and odd indices along the negative one;
// face / 2 selects the axis.
__device__ float3 getFaceDirection(int face) {
	if (face < 0 || face > 5) {
		return make_float3(0.0f, 0.0f, 1.0f);
	}

	const float sgn = (face & 1) ? -1.0f : 1.0f;
	const int axis = face / 2;

	return make_float3(axis == 0 ? sgn : 0.0f,
	                   axis == 1 ? sgn : 0.0f,
	                   axis == 2 ? sgn : 0.0f);
}

// Converts an angle in degrees to radians, computed as (deg * pi) / 180.
__device__ float deg2rad(float deg) {
	const float scaled = deg * CUDART_PI_F;
	return scaled / 180.0f;
}

// Weight of RSM pixel (ix.x, ix.y) in an sz.x x sz.y image, measured on the RSM camera's near plane.
//
// The numerator is the pixel's footprint on the near plane (4 * aspect * tan^2(fov/2) * near^2 /
// (sz.x*sz.y)) times the near distance. The denominator is r^3, where r is the distance from the
// camera to the pixel centre on the near plane. This is dA * near / r^3, the usual solid-angle
// approximation for a planar pixel.
// ix.z and sz.z are not used.
__device__ float calculatePixelWeight(int3 ix, int3 sz) {
	// Pin the camera macros to float so the math stays single precision even if Defines.h
	// expresses them as double literals. NOTE(review): confirm their declared types.
	const float nearDist = RSM_CAMERA_NEAR;
	const float aspect = RSM_CAMERA_ASPECT;

	const float tanFovHalf = tanf(deg2rad(0.5f * RSM_CAMERA_FOV));
	const float nearPlaneHeight = nearDist * 2.0f * tanFovHalf;

	// Explicit products instead of pow(x, 2.0f) / pow(x, 3.0f).
	const float texture_PixelWeight = 4.0f * aspect * (tanFovHalf * tanFovHalf)
		* (nearDist * nearDist * nearDist) / (sz.x * sz.y);

	const float2 resolution = make_float2(sz.x, sz.y);
	const float2 nearPlaneSize = make_float2(aspect * nearPlaneHeight, nearPlaneHeight);
	const float2 texture_PixelSize = nearPlaneSize / resolution;
	// Pixel-centre offset from the image centre, in near-plane units.
	const float2 nearPos = (make_float2(ix.x, ix.y) - 0.5f * resolution + 0.5f) * texture_PixelSize;

	// r^3 = (r^2)^(3/2) = r^2 * sqrt(r^2), avoiding pow(x, 1.5f).
	const float r2 = nearDist * nearDist + dot(nearPos, nearPos);
	return texture_PixelWeight / (r2 * sqrtf(r2));
}

// True only when every component of v equals zero (+0.0f or -0.0f). A NaN component yields false.
// The previous version compared the *sum* of the components to zero. That also returned true for
// non-zero vectors whose entries cancel, e.g. (1, -1, 0, 0), and SH coefficient vectors in this
// file routinely have negative entries.
__device__ bool isZero(const float4& v) {
	return v.x == 0.0f && v.y == 0.0f && v.z == 0.0f && v.w == 0.0f;
}
