#include "LightPrerenderSetup.h"
#include "Defines.h"
#include "CallbackUtils.h"
#include "CalculateBoundingBoxVisitor.h"
#include "TimeMonitor.h"

#include "cuda/CommonStructs.h"

#include "data/LightPropagationVolume.h"
#include "data/ReflectiveShadowMap.h"
#include "data/GeometryVolume.h"
#include "data/IndexVolume.h"

#include "modules/GVClearModule.h"
#include "modules/GVInjectModule.h"
#include "modules/GVDownsampleModule.h"
#include "modules/LPVClearModule.h"
#include "modules/LPVInjectModule.h"
#include "modules/LPVDownsampleModule.h"
#include "modules/LPVPropagateModule.h"
#include "modules/LPVDebugModule.h"
#include "modules/LPVMergeModule.h"
#include "modules/GVDebugModule.h"
#include "modules/IXClearModule.h"

#include <osg/Matrixd>
#include <osg/Uniform>
#include <osg/MatrixTransform>
#include <osg/Program>
#include <osgCuda/Texture>
#include <osgCuda/Computation>
#include <osgCuda/Memory>
#include <osgCompute/Computation>

#include <iostream>
#include <cstring>

/*
 * Creates a visitor that walks the whole scene graph (TRAVERSE_ALL_CHILDREN) and collects
 * light sources via apply(). The LPV pointer is explicitly cleared so that getLPV() yields
 * null, rather than an indeterminate value, if it is called before setup().
 */
LightPrerenderSetup::LightPrerenderSetup(SceneConfig* config) : osg::NodeVisitor(NODE_VISITOR, TRAVERSE_ALL_CHILDREN), config(config) {
	lpv = NULL;
}

/*
 * Records every light source met during traversal. Traversal then continues into the light
 * source's children. Overriding apply() without calling traverse() would otherwise prune the
 * subtree despite TRAVERSE_ALL_CHILDREN, so nested light sources would never be collected.
 */
void LightPrerenderSetup::apply(osg::LightSource &node) {
	lightSources.push_back(&node);
	traverse(node);
}

void LightPrerenderSetup::setup(osg::Node* scene, osg::Group* root, osg::StateSet* ss) {

	osg::StateSet* rootSs = root->getOrCreateStateSet();

	osg::Matrixd projectionMatrix;
	projectionMatrix.makePerspective(RSM_CAMERA_FOV, RSM_CAMERA_ASPECT, RSM_CAMERA_NEAR, RSM_CAMERA_FAR);
	//projectionMatrix.makeFrustum(-1.0,1.0,-1.0,1.0,CAMERA_NEAR,CAMERA_FAR);

	//const osg::BoundingSphere& bounds = scene->getBound();
	CalculateBoundingBoxVisitor cbbv;
	scene->accept(cbbv);
	const osg::BoundingBox& bb = cbbv.getBoundBox();

	/*
	 * Define uniforms that provide the render-to-screen pass shader with information about the light sources and
	 * the pre-render stages.
	 */
	osg::Uniform* lightViewUniform = new osg::Uniform(osg::Uniform::FLOAT_MAT4, "light_ModelViewMatrices", lightSources.size());
	osg::Uniform* lightProjectionUniform = new osg::Uniform("light_ProjectionMatrix", projectionMatrix); // This one is actually not used in the shader...

	osg::Uniform* lightDepthMapUniform = new osg::Uniform(osg::Uniform::SAMPLER_3D, "light_DepthMaps", lightSources.size());
	osg::Uniform* lightColorMapUniform = new osg::Uniform(osg::Uniform::SAMPLER_3D, "light_ColorMaps", lightSources.size());
	osg::Uniform* lightNormalMapUniform = new osg::Uniform(osg::Uniform::SAMPLER_3D, "light_NormalMaps", lightSources.size());

	osg::Uniform* lightPositionUniform = new osg::Uniform(osg::Uniform::FLOAT_VEC4, "light_Positions", lightSources.size());
	osg::Uniform* lightColorUniform = new osg::Uniform(osg::Uniform::FLOAT_VEC4, "light_Colors", lightSources.size());

	rootSs->addUniform(lightViewUniform);
	rootSs->addUniform(lightProjectionUniform);

	rootSs->addUniform(lightDepthMapUniform);
	rootSs->addUniform(lightColorMapUniform);
	rootSs->addUniform(lightNormalMapUniform);

	rootSs->addUniform(lightPositionUniform);
	rootSs->addUniform(lightColorUniform);

	rootSs->addUniform(new osg::Uniform("light_FarPlane", RSM_CAMERA_FAR));
	rootSs->addUniform(new osg::Uniform("light_NearPlane", RSM_CAMERA_NEAR));


	osg::Uniform* lightMapUniforms[] = {lightDepthMapUniform, lightColorMapUniform, lightNormalMapUniform};




	// Prepare the scene graph nodes for the CUDA computation chain.
	osg::ref_ptr<osgCompute::Computation> clearComputation = new osgCuda::Computation;
	osg::ref_ptr<osgCompute::Computation> injectComputation = new osgCuda::Computation;
	osg::ref_ptr<osgCompute::Computation> debugInjectComputation = new osgCuda::Computation;
	osg::ref_ptr<osgCompute::Computation> downsampleComputation = new osgCuda::Computation;
	osg::ref_ptr<osgCompute::Computation> propagateComputation = new osgCuda::Computation;
	osg::ref_ptr<osgCompute::Computation> debugPropagateComputation = new osgCuda::Computation;
	osg::ref_ptr<osgCompute::Computation> mergeComputation = new osgCuda::Computation;
	{
		clearComputation->setComputeOrder(osgCompute::Computation::PRERENDER_AFTERCHILDREN);
		injectComputation->setComputeOrder(osgCompute::Computation::PRERENDER_AFTERCHILDREN);
		debugInjectComputation->setComputeOrder(osgCompute::Computation::PRERENDER_AFTERCHILDREN);
		downsampleComputation->setComputeOrder(osgCompute::Computation::PRERENDER_AFTERCHILDREN);
		propagateComputation->setComputeOrder(osgCompute::Computation::PRERENDER_AFTERCHILDREN);
		mergeComputation->setComputeOrder(osgCompute::Computation::PRERENDER_AFTERCHILDREN);
		debugPropagateComputation->setComputeOrder(osgCompute::Computation::PRERENDER_AFTERCHILDREN);

		root->addChild(debugPropagateComputation);
		debugPropagateComputation->addChild(mergeComputation);
		mergeComputation->addChild(propagateComputation);
		propagateComputation->addChild(downsampleComputation);
		downsampleComputation->addChild(debugInjectComputation);
		debugInjectComputation->addChild(injectComputation);
		injectComputation->addChild(clearComputation);

	}

	// Prepare the pre-CUDA root node.
	osg::ref_ptr<osg::Group> prerenderRoot = new osg::Group;
	{
		clearComputation->addChild(prerenderRoot);
	}


	// Create the actual LPV and hook it up to the appropriate shader uniforms.
	lpv = new data::LightPropagationVolume(LPV_SIZE, "lpv");
	{
		osg::Uniform* lpvUniform = new osg::Uniform(osg::Uniform::SAMPLER_3D, "volume_LightPropagation", 3);

		data::LPVComponents* buffer = lpv->getAccumulationBuffer();

		rootSs->setTextureAttributeAndModes(TEXTURE_ID_LPV_START+0, buffer->getRed()->getTexture(), osg::StateAttribute::ON);
		rootSs->setTextureAttributeAndModes(TEXTURE_ID_LPV_START+1, buffer->getGreen()->getTexture(), osg::StateAttribute::ON);
		rootSs->setTextureAttributeAndModes(TEXTURE_ID_LPV_START+2, buffer->getBlue()->getTexture(), osg::StateAttribute::ON);
		lpvUniform->setElement(0, TEXTURE_ID_LPV_START+0);
		lpvUniform->setElement(1, TEXTURE_ID_LPV_START+1);
		lpvUniform->setElement(2, TEXTURE_ID_LPV_START+2);

		rootSs->addUniform(lpvUniform);
	}
	{
		osg::Uniform* debugUniform = new osg::Uniform(osg::Uniform::SAMPLER_3D, "volume_Debug", 3);
		data::LPVComponents* debugBuffer = lpv->getDebugBuffer();

		rootSs->setTextureAttributeAndModes(TEXTURE_ID_LPV_START+3, debugBuffer->getRed()->getTexture(), osg::StateAttribute::ON);
		rootSs->setTextureAttributeAndModes(TEXTURE_ID_LPV_START+4, debugBuffer->getGreen()->getTexture(), osg::StateAttribute::ON);
		rootSs->setTextureAttributeAndModes(TEXTURE_ID_LPV_START+5, debugBuffer->getBlue()->getTexture(), osg::StateAttribute::ON);
		debugUniform->setElement(0, TEXTURE_ID_LPV_START+3);
		debugUniform->setElement(1, TEXTURE_ID_LPV_START+4);
		debugUniform->setElement(2, TEXTURE_ID_LPV_START+5);

		rootSs->addUniform(debugUniform);


		osg::Uniform* encodedDebugUniform = new osg::Uniform(osg::Uniform::SAMPLER_3D, "volume_EncodedDebug", 3);
		data::LPVComponents* encodedDebugBuffer = lpv->getEncodedDebugBuffer();

		rootSs->setTextureAttributeAndModes(TEXTURE_ID_LPV_START+6, encodedDebugBuffer->getRed()->getTexture(), osg::StateAttribute::ON);
		rootSs->setTextureAttributeAndModes(TEXTURE_ID_LPV_START+7, encodedDebugBuffer->getGreen()->getTexture(), osg::StateAttribute::ON);
		rootSs->setTextureAttributeAndModes(TEXTURE_ID_LPV_START+8, encodedDebugBuffer->getBlue()->getTexture(), osg::StateAttribute::ON);
		encodedDebugUniform->setElement(0, TEXTURE_ID_LPV_START+6);
		encodedDebugUniform->setElement(1, TEXTURE_ID_LPV_START+7);
		encodedDebugUniform->setElement(2, TEXTURE_ID_LPV_START+8);

		rootSs->addUniform(encodedDebugUniform);
	}


	// Create the geometry volume and hook it up to the appropriate shader uniforms.
	data::GeometryVolume* gv = new data::GeometryVolume(LPV_SIZE, "gv");
	{
		rootSs->setTextureAttributeAndModes(TEXTURE_ID_GV, gv->getSource()->getTexture(), osg::StateAttribute::ON);
		rootSs->addUniform(new osg::Uniform("volume_Geometry", TEXTURE_ID_GV));


		rootSs->addUniform(new osg::Uniform("volume_Min", bb._min));
		rootSs->addUniform(new osg::Uniform("volume_Max", bb._max));
		rootSs->addUniform(new osg::Uniform("volume_Size", LPV_SIZE));
	}

	// Create the global index volume for the octree structure.
	data::IndexVolume* ix = new data::IndexVolume(LPV_SIZE, "ix");

	// Add all of the computation modules that behave as singletons.
	{
		// Add the LPV propagation module.
		{
			modules::LPVPropagateModule* propagateModule = new modules::LPVPropagateModule;
			propagateModule->setConfig(config);
			propagateModule->setLightPropagationVolume(lpv);
			propagateModule->setGeometryVolume(gv);
			propagateModule->setIndexVolume(ix);
			propagateComputation->addModule(*propagateModule);
		}

		// Add the LPV propagation debug module.
		{
			modules::LPVDebugModule* debugPropagateModule = new modules::LPVDebugModule;
			debugPropagateModule->setConfig(config);
			debugPropagateModule->setLightPropagationVolume(lpv);
			debugPropagateModule->setDebugMode(SceneConfig::DEBUG_PROPAGATION);
			debugPropagateModule->setTargetBuffer(data::LightPropagationVolume::ACCUMULATION);
			debugPropagateComputation->addModule(*debugPropagateModule);
		}

		// Add the LPV downsample module
		{
			modules::LPVDownsampleModule* lpvDownsampleModule = new modules::LPVDownsampleModule;
			lpvDownsampleModule->setConfig(config);
			lpvDownsampleModule->setLightPropagationVolume(lpv);
			lpvDownsampleModule->setIndexVolume(ix);
			downsampleComputation->addModule(*lpvDownsampleModule);
		}

		// Add the GV downsample module
		{
			modules::GVDownsampleModule* gvDownsampleModule = new modules::GVDownsampleModule;
			gvDownsampleModule->setConfig(config);
			gvDownsampleModule->setGeometryVolume(gv);
			gvDownsampleModule->setIndexVolume(ix);
			downsampleComputation->addModule(*gvDownsampleModule);
		}

		// Add the LPV injection debug module.
		{
			modules::LPVDebugModule* debugInjectionModule = new modules::LPVDebugModule;
			debugInjectionModule->setConfig(config);
			debugInjectionModule->setLightPropagationVolume(lpv);
			debugInjectionModule->setDebugMode(SceneConfig::DEBUG_INJECTION);
			debugInjectionModule->setTargetBuffer(data::LightPropagationVolume::TARGET);
			debugInjectComputation->addModule(*debugInjectionModule);
		}

		// Add the LPV merge module.
		{
			modules::LPVMergeModule* lpvMergeModule = new modules::LPVMergeModule;
			lpvMergeModule->setConfig(config);
			lpvMergeModule->setLightPropagationVolume(lpv);
			lpvMergeModule->setIndexVolume(ix);
			mergeComputation->addModule(*lpvMergeModule);
		}

		// Add the GV injection debug module.
		{
			modules::GVDebugModule* debugGeometryModule = new modules::GVDebugModule;
			debugGeometryModule->setConfig(config);
			debugGeometryModule->setLightPropagationVolume(lpv);
			debugGeometryModule->setGeometryVolume(gv);
			debugInjectComputation->addModule(*debugGeometryModule);
		}

		// Add the clear modules.
		{
			modules::LPVClearModule* lpvClearModule = new modules::LPVClearModule;
			lpvClearModule->setConfig(config);
			lpvClearModule->setLightPropagationVolume(lpv);
			clearComputation->addModule(*lpvClearModule);


			modules::GVClearModule* gvClearModule = new modules::GVClearModule;
			gvClearModule->setConfig(config);
			gvClearModule->setGeometryVolume(gv);
			clearComputation->addModule(*gvClearModule);

			modules::IXClearModule* ixClearModule = new modules::IXClearModule;
			ixClearModule->setConfig(config);
			ixClearModule->setIndexVolume(ix);
			ixClearModule->setInitialValue(lpv->getLevelCount()-1);
			clearComputation->addModule(*ixClearModule);
		}

	}



	/*
	 * Construct the common path of the scene graph that all RTT cameras will be have
	 * attached to them. Also set up the RTT shader program.
	 */
	osg::ref_ptr<osg::Group> prerenderScene = new osg::Group;
	prerenderScene->addChild(scene);
	{
		osg::ref_ptr<osg::Program> shaderProgram = new osg::Program;

		osg::ref_ptr<osg::Shader> vs = new osg::Shader(osg::Shader::VERTEX);
		osg::ref_ptr<osg::Shader> fs = new osg::Shader(osg::Shader::FRAGMENT);

		vs->loadShaderSourceFromFile("media/LightDepthMap.vs");
		fs->loadShaderSourceFromFile("media/LightDepthMap.fs");
		shaderProgram->addShader(vs);
		shaderProgram->addShader(fs);

		osg::StateSet* ss = prerenderScene->getOrCreateStateSet();
		ss->setAttributeAndModes(
				shaderProgram, osg::StateAttribute::ON);
		ss->addUniform(new osg::Uniform("camera_FarPlane", RSM_CAMERA_FAR));
		ss->addUniform(new osg::Uniform("camera_NearPlane", RSM_CAMERA_NEAR));
		ss->addUniform(new osg::Uniform("camera_Resolution", osg::Vec2(RSM_RESOLUTION, RSM_RESOLUTION)));
		ss->addUniform(new osg::Uniform("camera_FieldOfView", RSM_CAMERA_FOV));
		ss->addUniform(new osg::Uniform("camera_AspectRatio", RSM_CAMERA_ASPECT));
		ss->addUniform(lightPositionUniform);
		ss->addUniform(lightColorUniform);
	}

	/*
	 * Iterate over the light sources and set up the RTT cameras required for each one.
	 */
	int i = 0;
	std::vector<osg::LightSource*>::iterator it;
	for(it = lightSources.begin(); it != lightSources.end(); ++it) {
		osg::LightSource &node = **it;

		/*
		 * Register the automatic per-frame callbacks for updating the uniforms defined in the beginning.
		 */
		node.addCullCallback(new CopyLightPositionToUniformCallback(lightPositionUniform, i));
		node.addCullCallback(new CopyLightColorToUniformCallback(lightColorUniform, i));
		node.addCullCallback(new CopyLightViewToUniformCallback(lightViewUniform, i));

		/*
		 * The root node for the pre-render stage(s) involved for ONE light source.
		 */
		osg::ref_ptr<osg::Group> lightSourceRoot = new osg::Group;
		prerenderRoot->addChild(lightSourceRoot);

		/*
		 * Create and initialize the shadow cubemap texture.
		 */
		data::ReflectiveShadowMap* rsm = new data::ReflectiveShadowMap(RSM_RESOLUTION, RSM_SAMPLE_RESOLUTION, (*it)->getName() + "_rsm");
		for(int componentId = 0; componentId < 3; componentId++) {
			data::ReflectiveShadowMap::Component component = static_cast<data::ReflectiveShadowMap::Component>(componentId);
			/*
			 * Bind the texture to a free texture id and attach it to the light_DepthMaps uniform array.
			 */
			ss->setTextureAttributeAndModes(TEXTURE_ID_RSM_START+componentId*3+i, rsm->getTexture(component), osg::StateAttribute::ON);
			lightMapUniforms[componentId]->setElement(i, TEXTURE_ID_RSM_START+componentId*3+i);
		}


		/*
		 * Set up shared data for the LPV and GV injection modules.
		 */
		Transform* transform = new Transform;
		osgCompute::Memory* boundsMemory = new osgCuda::Memory;
		{
			node.addCullCallback(new CopyLightViewToCudaMemoryCallback(transform));

			boundsMemory->addIdentifier("bounding_box");
			boundsMemory->setElementSize(sizeof(float));
			boundsMemory->setDimension(0, 6);
			float* boundsHost = (float*)boundsMemory->map(osgCompute::MAP_HOST_TARGET);
			memcpy(boundsHost, bb._min._v, 3*sizeof(float));
			memcpy(boundsHost + 3, bb._max._v, 3*sizeof(float));
		}


		/*
		 * Set up the LPV injection module for this light source.
		 */
		{
			modules::LPVInjectModule* lpvInjectModule = new modules::LPVInjectModule;
			lpvInjectModule->setConfig(config);
			lpvInjectModule->setLightPropagationVolume(lpv);
			lpvInjectModule->setReflectiveShadowMap(rsm);
			lpvInjectModule->setIndexVolume(ix);

			lpvInjectModule->setTransform(transform);
			lpvInjectModule->acceptResource(*boundsMemory);
			injectComputation->addModule(*lpvInjectModule);
		}


		/*
		 * Set up the GV injection module for this light source.
		 */
		{
			modules::GVInjectModule* gvInjectModule = new modules::GVInjectModule;
			gvInjectModule->setConfig(config);
			gvInjectModule->setGeometryVolume(gv);
			gvInjectModule->setReflectiveShadowMap(rsm);
			gvInjectModule->setIndexVolume(ix);

			gvInjectModule->setTransform(transform);
			gvInjectModule->acceptResource(*boundsMemory);
			injectComputation->addModule(*gvInjectModule);
		}

		/*
		 * Add an actual RTT camera for each one of the six faces around the light source, each
		 * one aligned with a global axis in a certain direction.
		 */
		for(int face = 0; face < 6; face++) {
			osg::ref_ptr<osg::Camera> prerenderCamera = new osg::Camera;

			/*
			 * Our encoding implies that white is really really really far away. Thus we set
			 * the background to white.
			 */
			prerenderCamera->setClearColor(osg::Vec4(0.0f, 0.0f, 0.0f, 0.0f));
			prerenderCamera->setClearMask(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
			prerenderCamera->setProjectionMatrix(projectionMatrix);


			prerenderCamera->setPreDrawCallback(new BeforePrerenderCallback);
			prerenderCamera->setPostDrawCallback(new AfterPrerenderCallback);

			/*
			 * The RTT camera should be independent of transforms from higher levels in the scene graph.
			 */
			prerenderCamera->setReferenceFrame(osg::Transform::ABSOLUTE_RF);

			prerenderCamera->setViewport(0, 0, RSM_RESOLUTION, RSM_RESOLUTION);

			/*
			 * Ensure that this camera renders before the main camera.
			 */
			prerenderCamera->setRenderOrder(osg::Camera::PRE_RENDER);

			/*
			 * This camera should render to a FBO rather then to the screen.
			 */
			prerenderCamera->setRenderTargetImplementation(osg::Camera::FRAME_BUFFER_OBJECT);

			/*
			 * Specify which texture/FBO that the camera should render to.
			 */
			for(int componentId = 0; componentId < 3; componentId++) {
				data::ReflectiveShadowMap::Component component = static_cast<data::ReflectiveShadowMap::Component>(componentId);
				osg::Camera::BufferComponent buffer = static_cast<osg::Camera::BufferComponent>(osg::Camera::COLOR_BUFFER0+componentId);
				prerenderCamera->attach(buffer, rsm->getTexture(component), 0, face, false);
			}

			osg::StateSet* ssCam = prerenderCamera->getOrCreateStateSet();
			ssCam->addUniform(new osg::Uniform("light_Id", i));

			/*
			 * Add the scene that each RTT camera should render. In this case the sponza scene, but
			 * we refer to a lower level node then the "realScene" node used in the main function
			 * since we have specified our own shader and state settings that we don't want to have
			 * overwritten.
			 */
			prerenderCamera->addChild(prerenderScene);

			/*
			 * Add this RTT camera to this light source's pre-render root.
			 */
			lightSourceRoot->addChild(prerenderCamera);

			/*
			 * Ensure that this camera will be automatically moved whenever the corresponding light source is moved.
			 */
			node.addCullCallback(new UpdateCameraCallback(prerenderCamera, face));
		}

		++i;
		std::cout << "Successfully installed light source " << i << std::endl;
	}
}

/*
 * Returns the light propagation volume that setup() created and wired into the scene graph.
 */
data::LightPropagationVolume* LightPrerenderSetup::getLPV() const {
	return this->lpv;
}
