This is an assignment written in C++ for a Computer Graphics class. The goal is to implement ray tracing for a triangle mesh and to add an acceleration structure that makes the computation faster.
The program builds on the provided starter code; the specific tasks and my results are described below.
Here we implement an acceleration structure to speed up rendering an image with ray tracing. When compiled in release mode, this typically gives a speedup of 200x to 500x compared with the version without any acceleration structure.
The data structure studied in this exercise is the AABB tree, a special type of Bounding Volume Hierarchy (BVH) in which each node of the tree is associated with the bounding box of its content. More specifically, each leaf node contains a single triangle and stores the bounding box of that triangle. Each internal node has two children (for the sake of simplicity) and stores a box that is the union of the two boxes stored in its child nodes.
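To make the node layout concrete, here is a minimal sketch in plain standard C++. The code listing below uses Eigen's `AlignedBox3d` instead; `Box`, `merge`, `triangle_box`, and `add_internal` are hypothetical names chosen for illustration. A leaf stores the bounding box of its triangle, and an internal node stores the union of its two children's boxes.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <limits>
#include <vector>

using Vec3 = std::array<double, 3>;

constexpr double kInf = std::numeric_limits<double>::infinity();

// Axis-aligned box; it starts empty (lo = +inf, hi = -inf), so the first
// extend() call collapses it onto that point.
struct Box {
    Vec3 lo{kInf, kInf, kInf};
    Vec3 hi{-kInf, -kInf, -kInf};

    // Grow the box so that it contains point p.
    void extend(const Vec3 &p) {
        for (int k = 0; k < 3; ++k) {
            lo[k] = std::min(lo[k], p[k]);
            hi[k] = std::max(hi[k], p[k]);
        }
    }
};

// Union of two boxes: the box stored in an internal node.
Box merge(const Box &a, const Box &b) {
    Box r;
    for (int k = 0; k < 3; ++k) {
        r.lo[k] = std::min(a.lo[k], b.lo[k]);
        r.hi[k] = std::max(a.hi[k], b.hi[k]);
    }
    return r;
}

// Leaf box: the bounding box of a single triangle (a, b, c).
Box triangle_box(const Vec3 &a, const Vec3 &b, const Vec3 &c) {
    Box box;
    box.extend(a);
    box.extend(b);
    box.extend(c);
    return box;
}

// One tree node; -1 means "none", mirroring the Node struct in the code below.
struct Node {
    Box bbox;
    int left = -1;      // left child (-1 for a leaf)
    int right = -1;     // right child (-1 for a leaf)
    int triangle = -1;  // triangle index (-1 for an internal node)
};

// Append an internal node whose box is the union of its children's boxes.
int add_internal(std::vector<Node> &nodes, int left, int right) {
    Node n;
    n.left = left;
    n.right = right;
    n.bbox = merge(nodes[left].bbox, nodes[right].bbox);
    nodes.push_back(n);
    return static_cast<int>(nodes.size()) - 1;
}
```

Storing nodes in one `std::vector` and linking them by index, as the assignment's `AABBTree` also does, keeps the tree compact and avoids per-node heap allocations.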
We propose to implement two different approaches to build an AABB Tree for triangles.
Top-Down Construction
In this approach, the input triangles are split into groups of roughly equal size starting from the top node. This can be described as a recursive function that does the following:
1-Split the input set of triangles into two sets S1 and S2.
2-Recursively build the subtree T1 corresponding to S1.
3-Recursively build the subtree T2 corresponding to S2.
4-Update the box of the current node by merging boxes of the root of T1 and T2.
5-Return the new root node R.
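The five steps above map directly onto a recursive function. This is a sketch under stated assumptions: it uses hypothetical standard-C++ types (`Tri`, `Box`, `Node`) rather than Eigen, and the splitting rule is passed in as a callback (`SplitFn`, an illustrative name) so that the criterion from the next section can be plugged in separately.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <functional>
#include <vector>

using Vec3 = std::array<double, 3>;
using Tri = std::array<Vec3, 3>;
struct Box { Vec3 lo, hi; };

// Bounding box of one triangle.
Box tri_box(const Tri &t) {
    Box b{t[0], t[0]};
    for (const Vec3 &p : t)
        for (int k = 0; k < 3; ++k) {
            b.lo[k] = std::min(b.lo[k], p[k]);
            b.hi[k] = std::max(b.hi[k], p[k]);
        }
    return b;
}

// Union of two boxes.
Box merge(const Box &a, const Box &b) {
    Box r;
    for (int k = 0; k < 3; ++k) {
        r.lo[k] = std::min(a.lo[k], b.lo[k]);
        r.hi[k] = std::max(a.hi[k], b.hi[k]);
    }
    return r;
}

struct Node {
    Box bbox;
    int left = -1, right = -1;  // children (-1 for a leaf)
    int triangle = -1;          // triangle index (-1 for an internal node)
};

// Hypothetical callback: splits triangle indices into non-empty S1 and S2.
using SplitFn = std::function<void(const std::vector<int> &,
                                   std::vector<int> &, std::vector<int> &)>;

// Steps 1-5 above; returns the index of the new root R in `nodes`.
int build_top_down(const std::vector<Tri> &tris, const std::vector<int> &ids,
                   const SplitFn &split, std::vector<Node> &nodes) {
    assert(!ids.empty());
    if (ids.size() == 1) {  // base case: a single triangle becomes a leaf
        Node leaf;
        leaf.triangle = ids[0];
        leaf.bbox = tri_box(tris[ids[0]]);
        nodes.push_back(leaf);
        return static_cast<int>(nodes.size()) - 1;
    }
    std::vector<int> s1, s2;
    split(ids, s1, s2);                                // 1. split into S1, S2
    assert(!s1.empty() && !s2.empty());
    int t1 = build_top_down(tris, s1, split, nodes);   // 2. build T1
    int t2 = build_top_down(tris, s2, split, nodes);   // 3. build T2
    Node r;
    r.left = t1;
    r.right = t2;
    r.bbox = merge(nodes[t1].bbox, nodes[t2].bbox);    // 4. merge child boxes
    nodes.push_back(r);
    return static_cast<int>(nodes.size()) - 1;         // 5. return root R
}
```

Because the children are pushed before their parent, the root always ends up as the last node, which is also what the assignment's `TopDownSplit` returns. A tree over k triangles built this way has exactly 2k - 1 nodes.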
Sorting Criterion
We need a criterion to split the set of input triangles S into two subsets S1 and S2. We propose to simply sort the input triangles by the coordinate of their centroid along the longest axis of the box spanned by those centroids. Then S1 holds the left half (rounded up), and S2 holds the right half (rounded down).
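One way to implement this criterion is sketched below. `split_longest_axis` is a hypothetical helper, and one substitution is made on purpose: instead of a full sort it uses `std::nth_element`, a partial sort that only guarantees every index before the split position has a centroid coordinate no larger than every index after it, which is all the split needs.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

using Vec3 = std::array<double, 3>;

// Splits triangle indices `ids` (centroids in `cent`) into S1 (first half,
// rounded up) and S2 (second half, rounded down) along the longest axis of
// the box spanned by their centroids. Assumes ids.size() >= 2.
void split_longest_axis(const std::vector<Vec3> &cent, std::vector<int> ids,
                        std::vector<int> &s1, std::vector<int> &s2) {
    assert(ids.size() >= 2);
    // Box spanned by the centroids of this subset.
    Vec3 lo = cent[ids[0]], hi = cent[ids[0]];
    for (int id : ids)
        for (int k = 0; k < 3; ++k) {
            lo[k] = std::min(lo[k], cent[id][k]);
            hi[k] = std::max(hi[k], cent[id][k]);
        }
    // Longest axis of that box.
    int axis = 0;
    for (int k = 1; k < 3; ++k)
        if (hi[k] - lo[k] > hi[axis] - lo[axis]) axis = k;
    // Partition around the median position: left half rounded up.
    const std::size_t half = (ids.size() + 1) / 2;
    std::nth_element(ids.begin(), ids.begin() + half, ids.end(),
                     [&](int a, int b) { return cent[a][axis] < cent[b][axis]; });
    s1.assign(ids.begin(), ids.begin() + half);
    s2.assign(ids.begin() + half, ids.end());
}
```

`std::nth_element` runs in linear time on average, so each level of the recursion costs O(n) instead of the O(n log n) of a full sort. The order of the indices within S1 and within S2 is unspecified, which does not matter for building the tree.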
Bottom-Up Construction
In this approach, we pair nodes of the tree iteratively until only one remains: the root of the tree. Efficient bottom-up BVH construction methods are more difficult to implement, since they require intelligent spatial sorting of the input data. We will settle for a simple quadratic algorithm in which, at each step, we merge the two nodes that minimize a certain criterion. The procedure can be summarized as follows:
1-Create a leaf node for every input triangle: N1, N2, ... Nk. Let S = {N1, N2, ..., Nk}.
2-Merge the two nodes Ni and Nj that minimize some criterion f(Ni, Nj).
3-Update S accordingly (remove Ni, Nj from S, and add merge(Ni, Nj) to S).
4-Repeat until |S| == 1.
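The four steps above can be sketched as follows, again with hypothetical standard-C++ `Box`/`Node` types rather than the Eigen-based code below. The criterion f is passed in as a callback (`CostFn`, an illustrative name) evaluated on the nodes' boxes, so either cost function from the next section can be plugged in. Each merge step scans all pairs in S, i.e. O(|S|^2) evaluations of f per step.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>

using Vec3 = std::array<double, 3>;
struct Box { Vec3 lo, hi; };

// Union of two boxes.
Box merge(const Box &a, const Box &b) {
    Box r;
    for (int k = 0; k < 3; ++k) {
        r.lo[k] = std::min(a.lo[k], b.lo[k]);
        r.hi[k] = std::max(a.hi[k], b.hi[k]);
    }
    return r;
}

struct Node {
    Box bbox;
    int left = -1, right = -1;  // children (-1 for a leaf)
    int triangle = -1;          // triangle index (-1 for an internal node)
};

// Hypothetical cost callback f(Ni, Nj), evaluated on the nodes' boxes.
using CostFn = std::function<double(const Box &, const Box &)>;

// Builds the tree bottom-up from one box per triangle; returns the root index.
int build_bottom_up(const std::vector<Box> &leaf_boxes, const CostFn &f,
                    std::vector<Node> &nodes) {
    assert(!leaf_boxes.empty());
    std::vector<int> S;  // 1. one leaf node per input triangle
    for (int i = 0; i < static_cast<int>(leaf_boxes.size()); ++i) {
        Node leaf;
        leaf.bbox = leaf_boxes[i];
        leaf.triangle = i;
        nodes.push_back(leaf);
        S.push_back(static_cast<int>(nodes.size()) - 1);
    }
    while (S.size() > 1) {  // 4. repeat until |S| == 1
        // 2. find the positions (a, b) in S of the pair minimizing f
        std::size_t a = 0, b = 1;
        double best = std::numeric_limits<double>::infinity();
        for (std::size_t i = 0; i + 1 < S.size(); ++i)
            for (std::size_t j = i + 1; j < S.size(); ++j) {
                double c = f(nodes[S[i]].bbox, nodes[S[j]].bbox);
                if (c < best) {
                    best = c;
                    a = i;
                    b = j;
                }
            }
        Node parent;
        parent.left = S[a];
        parent.right = S[b];
        parent.bbox = merge(nodes[S[a]].bbox, nodes[S[b]].bbox);
        nodes.push_back(parent);
        // 3. remove Ni and Nj (b > a, so erase b first) and add merge(Ni, Nj)
        S.erase(S.begin() + b);
        S.erase(S.begin() + a);
        S.push_back(static_cast<int>(nodes.size()) - 1);
    }
    return S[0];
}
```

For comparison, the assignment's `BottomUpSplit` below measures the distance between the averages of the triangle centroids in each subtree (computed by `calcCent`), a slight variant of comparing box centroids.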
Cost Function
To decide which pair of nodes to merge first, one can try either of the following criteria:
1. Take f(Ni, Nj) as the distance between the centroids of the boxes associated with Ni and Nj.
2. Take f(Ni, Nj) as the increase in volume of the new box (i.e. the volume of the union minus the volumes of Ni and Nj).
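Both criteria are easy to state in code. The sketch below assumes a simple hypothetical `Box` with `lo`/`hi` corners; with Eigen, `AlignedBox3d::center()`, `volume()`, and `merged()` would play the same roles.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>

using Vec3 = std::array<double, 3>;
struct Box { Vec3 lo, hi; };

double volume(const Box &b) {
    return (b.hi[0] - b.lo[0]) * (b.hi[1] - b.lo[1]) * (b.hi[2] - b.lo[2]);
}

// Criterion 1: Euclidean distance between the centers of the two boxes.
double centroid_distance(const Box &a, const Box &b) {
    double sq = 0.0;
    for (int k = 0; k < 3; ++k) {
        double d = 0.5 * (a.lo[k] + a.hi[k]) - 0.5 * (b.lo[k] + b.hi[k]);
        sq += d * d;
    }
    return std::sqrt(sq);
}

// Criterion 2: volume of the union box minus the volumes of both boxes.
double volume_increase(const Box &a, const Box &b) {
    Box u;
    for (int k = 0; k < 3; ++k) {
        u.lo[k] = std::min(a.lo[k], b.lo[k]);
        u.hi[k] = std::max(a.hi[k], b.hi[k]);
    }
    return volume(u) - volume(a) - volume(b);
}
```

For the unit cube and its copy shifted by 2 along x, criterion 1 gives 2 and criterion 2 gives 3 - 1 - 1 = 1. Two caveats are worth keeping in mind: criterion 2 is negative when the boxes overlap (for two identical unit cubes it gives 1 - 1 - 1 = -1), and a triangle lying in an axis-aligned plane has a flat box of zero volume, so criterion 2 may rate many such pairs identically.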
Tasks
1. Implement the intersection test of a ray and an axis-aligned bounding box. This test should be very simple; in particular, there should be no need to solve any linear system.
2. Implement the top-down or the bottom-up construction method described above.
3. Update the intersection code of a ray and a mesh to use this newly created BVH. In particular, now we should only need to test the intersection for leaf nodes whose bounding box also intersects the input ray.
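Tasks 1 and 3 can be sketched together. `ray_box`, `traverse`, and `LeafTest` are hypothetical names, and the leaf callback stands in for an exact ray-triangle test such as `intersect_triangle` in the code below. The box test is the standard slab method: it only divides and compares, with no linear system. The traversal is iterative and adds one optimization beyond what the task asks for: a subtree is skipped when the ray enters its box farther away than the closest hit found so far (the recursive `TraverseAABB` below visits both children of every box the ray hits and keeps the nearer result).

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <functional>
#include <limits>
#include <vector>

using Vec3 = std::array<double, 3>;
struct Box { Vec3 lo, hi; };
struct Node {
    Box bbox;
    int left = -1, right = -1;  // children (-1 for a leaf)
    int triangle = -1;          // triangle index (-1 for an internal node)
};

// Task 1, slab method: clip the ray parameter interval [0, +inf) against the
// three pairs of axis-aligned planes; the ray hits the box iff the interval
// stays non-empty. On IEEE 754 platforms 1.0 / 0.0 gives +/-inf, so rays
// parallel to an axis work, except the NaN case of an origin lying exactly
// on a box plane. t_enter receives the entry parameter (clamped to 0).
bool ray_box(const Vec3 &o, const Vec3 &d, const Box &b, double &t_enter) {
    double tmin = 0.0, tmax = std::numeric_limits<double>::infinity();
    for (int k = 0; k < 3; ++k) {
        double inv = 1.0 / d[k];
        double t1 = (b.lo[k] - o[k]) * inv, t2 = (b.hi[k] - o[k]) * inv;
        tmin = std::max(tmin, std::min(t1, t2));
        tmax = std::min(tmax, std::max(t1, t2));
    }
    t_enter = tmin;
    return tmin <= tmax;
}

// Hypothetical leaf callback: tests triangle `tri`; on a hit returns true and sets t.
using LeafTest = std::function<bool(int tri, double &t)>;

// Task 3: only leaves whose boxes the ray enters are tested exactly.
bool traverse(const std::vector<Node> &nodes, int root, const Vec3 &o, const Vec3 &d,
              const LeafTest &hit_leaf, int &best_tri, double &best_t) {
    best_tri = -1;
    best_t = std::numeric_limits<double>::infinity();
    std::vector<int> stack{root};
    while (!stack.empty()) {
        const Node &n = nodes[stack.back()];
        stack.pop_back();
        double t_enter;
        if (!ray_box(o, d, n.bbox, t_enter) || t_enter > best_t)
            continue;  // ray misses this box, or enters it beyond the best hit
        if (n.left == -1) {  // leaf: exact ray-triangle test
            double t;
            if (hit_leaf(n.triangle, t) && t < best_t) {
                best_t = t;
                best_tri = n.triangle;
            }
        } else {
            stack.push_back(n.right);
            stack.push_back(n.left);  // pushed last, so visited first
        }
    }
    return best_tri != -1;
}
```

An explicit stack avoids deep recursion and makes the pruning test easy to apply at every node. Visiting the child whose box the ray enters first (instead of always the left one) would usually prune even more subtrees.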
Result image (not included here): the speedup is not visible in the image itself, but this .json scene rendered in less time with the BVH.
////////////////////////////////////////////////////////////////////////////////
// C++ include
#include <algorithm>
#include <cassert>
#include <cmath>
#include <deque>
#include <fstream>
#include <iomanip>   // std::setprecision
#include <iostream>
#include <limits>
#include <memory>
#include <stack>
#include <string>
#include <vector>
// Eigen for matrix operations
#include <Eigen/Dense>
#include <Eigen/Geometry>
// Image writing library
#define STB_IMAGE_WRITE_IMPLEMENTATION // Do not include this line twice in your project!
#include "stb_image_write.h"
#include "utils.h"
// JSON parser library (https://github.com/nlohmann/json)
#include "json.hpp"
using json = nlohmann::json;
// Shortcut to avoid Eigen:: everywhere, DO NOT USE IN .h
using namespace Eigen;
////////////////////////////////////////////////////////////////////////////////
// Define types & classes
////////////////////////////////////////////////////////////////////////////////
struct Ray {
    Vector3d origin;
    Vector3d direction;
    Ray() { }
    Ray(Vector3d o, Vector3d d) : origin(o), direction(d) { }
};
struct Light {
    Vector3d position;
    Vector3d intensity;
};
struct Intersection {
    Vector3d position;
    Vector3d normal;
    double ray_param;
};
struct Camera {
    bool is_perspective;
    Vector3d position;
    double field_of_view; // between 0 and PI
    double focal_length;
    double lens_radius; // for depth of field
};
struct Material {
    Vector3d ambient_color;
    Vector3d diffuse_color;
    Vector3d specular_color;
    double specular_exponent; // Also called "shininess"
    Vector3d reflection_color;
    Vector3d refraction_color;
    double refraction_index;
};
struct Object {
    Material material;
    virtual ~Object() = default; // Classes with virtual methods should have a virtual destructor!
    virtual bool intersect(const Ray &ray, Intersection &hit) = 0;
};
// We use smart pointers to hold objects as this is a virtual class
typedef std::shared_ptr<Object> ObjectPtr;
struct Sphere : public Object {
    Vector3d position;
    double radius;
    virtual ~Sphere() = default;
    virtual bool intersect(const Ray &ray, Intersection &hit) override;
};
struct Parallelogram : public Object {
    Vector3d origin;
    Vector3d u;
    Vector3d v;
    virtual ~Parallelogram() = default;
    virtual bool intersect(const Ray &ray, Intersection &hit) override;
};
struct AABBTree {
    struct Node {
        AlignedBox3d bbox;
        int parent; // Index of the parent node (-1 for root)
        int left; // Index of the left child (-1 for a leaf)
        int right; // Index of the right child (-1 for a leaf)
        int triangle; // Index of the node triangle (-1 for internal nodes)
    };
    std::vector<Node> nodes;
    int root;
    AABBTree() = default; // Default empty constructor
    AABBTree(const MatrixXd &V, const MatrixXi &F); // Build a BVH from an existing mesh
    int TopDownSplit(const MatrixXd &V, MatrixXi F, MatrixXd C, int *indexArray);
    int BottomUpSplit(const MatrixXd &C, std::deque<int> &S);
    void calcCent(const Node &node, const MatrixXd &C, MatrixXd &cent, int &num);
};
struct Mesh : public Object {
    MatrixXd vertices; // n x 3 matrix (n points)
    MatrixXi facets; // m x 3 matrix (m triangles)
    AABBTree bvh;
    Mesh() = default; // Default empty constructor
    Mesh(const std::string &filename);
    virtual ~Mesh() = default;
    virtual bool intersect(const Ray &ray, Intersection &hit) override;
};
struct Scene {
    Vector3d background_color;
    Vector3d ambient_light;
    Camera camera;
    std::vector<Material> materials;
    std::vector<Light> lights;
    std::vector<ObjectPtr> objects;
};
////////////////////////////////////////////////////////////////////////////////
// Read a triangle mesh from an off file
void load_off(const std::string &filename, MatrixXd &V, MatrixXi &F) {
    std::ifstream in(filename);
    std::string token;
    in >> token;
    int nv, nf, ne;
    in >> nv >> nf >> ne;
    V.resize(nv, 3);
    F.resize(nf, 3);
    for (int i = 0; i < nv; ++i) {
        in >> V(i, 0) >> V(i, 1) >> V(i, 2);
    }
    for (int i = 0; i < nf; ++i) {
        int s;
        in >> s >> F(i, 0) >> F(i, 1) >> F(i, 2);
        assert(s == 3);
    }
}
Mesh::Mesh(const std::string &filename) {
    // Load a mesh from a file (assuming this is a .off file), and create a bvh
    load_off(filename, vertices, facets);
    bvh = AABBTree(vertices, facets);
}
////////////////////////////////////////////////////////////////////////////////
// BVH Implementation
////////////////////////////////////////////////////////////////////////////////
// Bounding box of a triangle
AlignedBox3d bbox_triangle(const Vector3d &a, const Vector3d &b, const Vector3d &c) {
    AlignedBox3d box;
    box.extend(a);
    box.extend(b);
    box.extend(c);
    return box;
}
AABBTree::AABBTree(const MatrixXd &V, const MatrixXi &F) {
    // Compute the centroids of all the triangles in the input mesh
    MatrixXd centroids(F.rows(), V.cols());  // n*3
    centroids.setZero();
    for (int i = 0; i < F.rows(); ++i) {
        for (int k = 0; k < F.cols(); ++k) {
            centroids.row(i) += V.row(F(i, k));
        }
        centroids.row(i) /= F.cols();
    }
    //////////////////////////////////////////////////////////////////
    // Method (1): Top-down approach.
    // Split each set of primitives into 2 sets of roughly equal size,
    // based on sorting the centroids along one direction or another.
    std::vector<int> indexArray(F.rows());  // original triangle index of each row of F
    for (int i = 0; i < F.rows(); ++i)
        indexArray[i] = i;
    root = TopDownSplit(V, F, centroids, indexArray.data());
    nodes[root].parent = -1;  // Store "parent" after parent node created 
    //////////////////////////////////////////////////////////////////
    // Method (2): Bottom-up approach.
    // Merge nodes 2 by 2, starting from the leaves of the forest, until only 1 tree is left.
    /*
    std::deque<int> S;
    for (int i = 0; i < F.rows(); ++i) {
        // first create nodes for all leaves (triangles)
        Node node;
        node.left = -1;
        node.right = -1;
        node.triangle = i;
        node.bbox = bbox_triangle(V.row(F(i, 0)), V.row(F(i, 1)), V.row(F(i, 2)));
        nodes.push_back(node);
        S.push_back(i);
    }
    root = BottomUpSplit(centroids, S);
    nodes[root].parent = -1;  // Store "parent" after parent node created 
    */
}
int AABBTree::TopDownSplit(const MatrixXd &V, MatrixXi F, MatrixXd C, int *indexArray) {
    // if recursion ends, 
    if (F.rows() == 1) {
        // this is a leaf node
        Node node;
        node.left = -1;
        node.right = -1;
        node.triangle = indexArray[0];
        node.bbox = bbox_triangle(V.row(F(0, 0)), V.row(F(0, 1)), V.row(F(0, 2)));
        nodes.push_back(node);
        return (nodes.size() - 1);
    }
    // find the index of the longest axis
    int longestAxisIndex = 0;
    for (int i = 1; i < C.cols(); ++i) {
        if (C.col(i).maxCoeff() - C.col(i).minCoeff() > C.col(longestAxisIndex).maxCoeff() - C.col(longestAxisIndex).minCoeff())
            longestAxisIndex = i;
    }
    // prepare left and right part containers (left half rounded up, as described above)
    const int n = F.rows();
    const int leftSize = (n + 1) / 2;
    MatrixXi Fl(leftSize, 3), Fr(n - leftSize, 3);
    MatrixXd Cl(leftSize, 3), Cr(n - leftSize, 3);
    std::vector<int> indexArrayl(leftSize), indexArrayr(n - leftSize);
    // partially sort so that the "leftSize" smallest centroids along the longest axis come first
    std::vector<int> sortIndex(n);  // auxiliary index array for sort
    for (int i = 0; i < n; ++i)
        sortIndex[i] = i;
    std::nth_element(sortIndex.begin(), sortIndex.begin() + leftSize, sortIndex.end(),
                     [&](int a, int b) { return C(a, longestAxisIndex) < C(b, longestAxisIndex); });
    // split left and right parts
    // left
    for (int i = 0; i < leftSize; ++i) {
        Fl.row(i) = F.row(sortIndex[i]);
        Cl.row(i) = C.row(sortIndex[i]);
        indexArrayl[i] = indexArray[sortIndex[i]];
    }
    // right
    for (int i = 0; i < n - leftSize; ++i) {
        Fr.row(i) = F.row(sortIndex[i + leftSize]);
        Cr.row(i) = C.row(sortIndex[i + leftSize]);
        indexArrayr[i] = indexArray[sortIndex[i + leftSize]];
    }
    // recurse
    int leftIndex = TopDownSplit(V, Fl, Cl, indexArrayl.data());
    int rightIndex = TopDownSplit(V, Fr, Cr, indexArrayr.data());
    // store and return
    Node node;
    node.left = leftIndex;
    node.right = rightIndex;
    node.triangle = -1;
    AlignedBox3d bbox = nodes[leftIndex].bbox;
    bbox.extend(nodes[rightIndex].bbox);
    node.bbox = bbox;
    nodes.push_back(node);
    nodes[leftIndex].parent = nodes.size() - 1;  // Store "parent" after parent node created
    nodes[rightIndex].parent = nodes.size() - 1;
    return (nodes.size() - 1);
}
void AABBTree::calcCent(const Node &node, const MatrixXd &C, MatrixXd &cent, int &num) {
// Auxiliary function for "BottomUpSplit" to do recursion
    if (node.left == -1) {  // leaf
        cent = cent + C.row(node.triangle);
        num++;
        return;
    }
    // recurse
    calcCent(nodes[node.left], C, cent, num);
    calcCent(nodes[node.right], C, cent, num);
}
int AABBTree::BottomUpSplit(const MatrixXd &C, std::deque<int> &S) {
    while (S.size() > 1) {
        // calc the centroid of all nodes
        MatrixXd C_(S.size(), 3);
        C_.setZero();
        int count = 0;
        for (int i : S) {
            MatrixXd cent = MatrixXd::Zero(1, 3);  // accumulator must start at zero
            int num = 0;
            calcCent(nodes[i], C, cent, num);
            C_(count, 0) = cent(0, 0) / num;
            C_(count, 1) = cent(0, 1) / num;
            C_(count, 2) = cent(0, 2) / num;
            ++count;
        }
        // find the min pair
        double minDis = std::numeric_limits<double>::max();
        int leftIndexS = -1, rightIndexS = -1;
        for (int i = 0; i<C_.rows()-1; ++i)
            for (int j = i+1; j<C_.rows(); ++j) {
                if ((C_.row(i) - C_.row(j)).squaredNorm() < minDis) {
                    minDis = (C_.row(i) - C_.row(j)).squaredNorm();
                    leftIndexS = i;
                    rightIndexS = j;
                }
            }
        // create new node
        Node node;
        node.left = S[leftIndexS];
        node.right = S[rightIndexS];
        node.triangle = -1;
        AlignedBox3d bbox = nodes[S[leftIndexS]].bbox;
        bbox.extend(nodes[S[rightIndexS]].bbox);
        node.bbox = bbox;
        nodes.push_back(node);
        nodes[S[leftIndexS]].parent = nodes.size() - 1;  // Store "parent" after parent node created
        nodes[S[rightIndexS]].parent = nodes.size() - 1;
        // update container S
        int temp = S[leftIndexS];
        S.erase(S.begin() + rightIndexS);
        for (auto it=S.begin(); it!=S.end(); ++it)
            if (*it == temp) {
                S.erase(it);
                break;
            }
        S.push_back(nodes.size() - 1);
    }
    // the root must be the last created node
    return nodes.size() - 1;
}
////////////////////////////////////////////////////////////////////////////////
bool Sphere::intersect(const Ray &ray, Intersection &hit) {
// Compute the intersection between the ray and the sphere
// If the ray hits the sphere, set the result of the intersection in the
// struct 'hit'
    double A = ray.direction.squaredNorm();
    double B = (ray.origin - position).dot(ray.direction) * 2;
    double C = (ray.origin - position).squaredNorm() - radius * radius;
    double delta = B * B - 4 * A * C;
    if (delta < 0) return false;
    double t1 = (-B - std::sqrt(delta)) / (2 * A);
    double t2 = (-B + std::sqrt(delta)) / (2 * A);
    if (t1 < 0 && t2 < 0) return false;
    t1 = (t1 < t2) ? t1 : t2;
    if (t1 < 0) t1 = t2;
    hit.position = ray.origin + ray.direction * t1;
    hit.normal = (hit.position - position).normalized();
    return true;
}
bool Parallelogram::intersect(const Ray &ray, Intersection &hit) {
    Vector3d y = ray.origin - origin;
    Matrix3d A;
    A.col(0) = u;
    A.col(1) = v;
    A.col(2) = -ray.direction;
    Vector3d x = A.colPivHouseholderQr().solve(y);
    if (!(x(0) >= 0 && x(0) <=1 && x(1) >= 0 && x(1) <= 1))
        return false;
    hit.position = ray.origin + x(2) * ray.direction;
    if ((ray.direction).dot(hit.position - ray.origin) < 0)
        return false;
    hit.normal = u.cross(v);
    if ((ray.origin - origin).dot(hit.normal) < 0)
        hit.normal = -1 * hit.normal;
    (hit.normal).normalize();
    return true;
}
// -----------------------------------------------------------------------------
bool intersect_triangle(const Ray &ray, const Vector3d &a, const Vector3d &b, const Vector3d &c, Intersection &hit) {
    // Modified from Parallelogram
    Vector3d origin = a;
    Vector3d u = b - a, v = c - a;
    // Rest almost identical
    Vector3d y = ray.origin - origin;
    Matrix3d A;
    A.col(0) = u;
    A.col(1) = v;
    A.col(2) = -ray.direction;
    Vector3d x = A.colPivHouseholderQr().solve(y);
    if (!(x(0) >= 0 && x(1) >= 0 && x(0) + x(1) <= 1))
        return false;
    hit.position = ray.origin + x(2) * ray.direction;
    if ((ray.direction).dot(hit.position - ray.origin) < 0)
        return false;
    hit.normal = u.cross(v);
    if ((ray.origin - origin).dot(hit.normal) < 0)
        hit.normal = -1 * hit.normal;
    (hit.normal).normalize();
    return true;
}
bool intersect_box(const Ray &ray, const AlignedBox3d &box) {
// Compute whether the ray intersects the given box.
// There is no need to set the resulting normal and ray parameter, since
// we are not testing with the real surface here anyway.
    // [ref: https://gamedev.stackexchange.com/a/18459]
    Vector3d min_ = box.min();
    Vector3d max_ = box.max();
    double  t1 = (min_(0) - ray.origin(0)) / ray.direction(0), 
            t2 = (max_(0) - ray.origin(0)) / ray.direction(0), 
            t3 = (min_(1) - ray.origin(1)) / ray.direction(1), 
            t4 = (max_(1) - ray.origin(1)) / ray.direction(1), 
            t5 = (min_(2) - ray.origin(2)) / ray.direction(2), 
            t6 = (max_(2) - ray.origin(2)) / ray.direction(2);
    
    double tmin = std::max(std::max(std::min(t1, t2), std::min(t3, t4)), std::min(t5, t6));
    double tmax = std::min(std::min(std::max(t1, t2), std::max(t3, t4)), std::max(t5, t6));
    // hit iff the three slab intervals overlap and the overlap is not entirely behind the origin
    return !(tmax < 0 || tmin > tmax);
}
bool TraverseAABB(const Ray &ray, const MatrixXd &V, const MatrixXi &F, const AABBTree &bvh, Intersection &closest_hit, int root) {
// Auxiliary function for mesh::intersection to do recursion
    // check leaf node for end of recursion
    if (bvh.nodes[root].left == -1) {
        int triIndex = bvh.nodes[root].triangle;
        return intersect_triangle(ray, V.row(F(triIndex, 0)), V.row(F(triIndex, 1)), V.row(F(triIndex, 2)), closest_hit);
    }
    // if this subtree has no intersection at all...
    if (!intersect_box(ray, bvh.nodes[root].bbox))
        return false;

    // prepare variables
    int leftIndex = bvh.nodes[root].left, rightIndex = bvh.nodes[root].right;
    Intersection leftIntersection, rightIntersection;
    // recurse
    bool leftHit = TraverseAABB(ray, V, F, bvh, leftIntersection, leftIndex);
    bool rightHit = TraverseAABB(ray, V, F, bvh, rightIntersection, rightIndex);
    // find the closest hit
    if (!leftHit && !rightHit) return false;  // no hit
    if (!leftHit && rightHit) closest_hit = rightIntersection;
    if (leftHit && !rightHit) closest_hit = leftIntersection;
    if (leftHit && rightHit) {
        if ((leftIntersection.position - ray.origin).squaredNorm() < (rightIntersection.position - ray.origin).squaredNorm())
            closest_hit = leftIntersection;
        else
            closest_hit = rightIntersection;
    }
    return true;  // have hit
}
bool Mesh::intersect(const Ray &ray, Intersection &closest_hit) {
    //////////////////////////////////////////////////////////////////
    // Method (1): Traverse every triangle and return the closest hit.
    /*
    Intersection hit;
    bool haveHit = false;
    for (int i=0; i<facets.rows(); ++i) {
    // for every triangle
        if (intersect_triangle(ray, vertices.row(facets(i, 0)), vertices.row(facets(i, 1)), vertices.row(facets(i, 2)), hit)) {
        // hit with i-th triangle
            if (haveHit) {
                if ((ray.origin - hit.position).squaredNorm() < (ray.origin - closest_hit.position).squaredNorm())
                    closest_hit = hit;
            } else {
                haveHit = true;
                closest_hit = hit;
            }
        }
    }
    */
    //////////////////////////////////////////////////////////////////
    // Method (2): Traverse the BVH tree and test the intersection with a
    // triangles at the leaf nodes that intersects the input ray.
    bool haveHit = TraverseAABB(ray, vertices, facets, bvh, closest_hit, bvh.root);
    //////////////////////////////////////////////////////////////////
    return haveHit;
}
////////////////////////////////////////////////////////////////////////////////
// Define ray-tracing functions
////////////////////////////////////////////////////////////////////////////////
// Function declaration here (could be put in a header file)
Vector3d ray_color(const Scene &scene, const Ray &ray, const Object &object, const Intersection &hit, int max_bounce);
Object * find_nearest_object(const Scene &scene, const Ray &ray, Intersection &closest_hit);
bool is_light_visible(const Scene &scene, const Ray &ray, const Light &light, const Vector3d hitPosition);
Vector3d shoot_ray(const Scene &scene, const Ray &ray, int max_bounce);
// -----------------------------------------------------------------------------
Vector3d ray_color(const Scene &scene, const Ray &ray, const Object &obj, const Intersection &hit, int max_bounce) {
    // Material for hit object
    const Material &mat = obj.material;
    // Ambient light contribution
    Vector3d ambient_color = obj.material.ambient_color.array() * scene.ambient_light.array();
    // Punctual lights contribution (direct lighting)
    Vector3d lights_color(0, 0, 0);
    for (const Light &light : scene.lights) {
        Vector3d Li = (light.position - hit.position).normalized();
        Vector3d N = hit.normal;
        // shadow rays
        Ray shadowRay(hit.position + 10e-7 * Li, Li);
        if (!is_light_visible(scene, shadowRay, light, hit.position)) continue;
        // Diffuse contribution
        Vector3d diffuse = mat.diffuse_color * std::max(Li.dot(N), 0.0);
        // Specular contribution
        Vector3d h = (Li - ray.direction).normalized();
        Vector3d specular = mat.specular_color * std::pow(std::max(N.dot(h), 0.0), mat.specular_exponent);
        // Attenuate lights according to the squared distance to the lights
        Vector3d D = light.position - hit.position;
        lights_color += (diffuse + specular).cwiseProduct(light.intensity) /  D.squaredNorm();
    }
    // (Assignment 2, reflected ray) Not implemented in this assignment
    Vector3d reflection_color(0, 0, 0);
    // (Assignment 2, refracted ray) Not implemented in this assignment
    Vector3d refraction_color(0, 0, 0);
    // Rendering equation
    Vector3d C = ambient_color + lights_color + reflection_color + refraction_color;
    return C;
}
// -----------------------------------------------------------------------------
Object * find_nearest_object(const Scene &scene, const Ray &ray, Intersection &closest_hit) {
    int closest_index = -1;
    Intersection hit;
    for (auto it=scene.objects.begin(); it!=scene.objects.end(); it++) {
        if ((*it)->intersect(ray, hit)) {
            // check whether closer
            if (closest_index == -1 || (hit.position - ray.origin).squaredNorm() < (closest_hit.position - ray.origin).squaredNorm()) {
                closest_index = it - scene.objects.begin();
                closest_hit = hit;
            }
        }
    }
    if (closest_index < 0) {
        // Return a NULL pointer
        return nullptr;
    } else {
        // Return a pointer to the hit object. Don't forget to set 'closest_hit' accordingly!
        return scene.objects[closest_index].get();
    }
}
bool is_light_visible(const Scene &scene, const Ray &ray, const Light &light, const Vector3d hitPosition) {
    // Shoot a shadow ray to determine if the light should affect the intersection point
    Intersection shadowHit;
    for (const ObjectPtr &object : scene.objects) {
        if (object->intersect(ray, shadowHit) && (hitPosition - shadowHit.position).dot(light.position - shadowHit.position) < 0) {
            return false;
        }
    }
    return true;
}
Vector3d shoot_ray(const Scene &scene, const Ray &ray, int max_bounce) {
    Intersection hit;
    if (Object * obj = find_nearest_object(scene, ray, hit)) {
        // 'obj' is not null and points to the object of the scene hit by the ray
        return ray_color(scene, ray, *obj, hit, max_bounce);
    } else {
        // 'obj' is null, we must return the background color
        return scene.background_color;
    }
}
////////////////////////////////////////////////////////////////////////////////
void render_scene(const Scene &scene) {
    std::cout << "Simple ray tracer." << std::endl;
    int w = 640;
    int h = 480;
    MatrixXd R = MatrixXd::Zero(w, h);
    MatrixXd G = MatrixXd::Zero(w, h);
    MatrixXd B = MatrixXd::Zero(w, h);
    MatrixXd A = MatrixXd::Zero(w, h); // Store the alpha mask
    // The camera always points in the direction -z
    // The sensor grid is at a distance 'focal_length' from the camera center,
    // and covers a viewing angle given by 'field_of_view'.
    double aspect_ratio = double(w) / double(h);
    double scale_y = scene.camera.focal_length * std::tan(scene.camera.field_of_view / 2);  // Stretch the pixel grid by the proper amount here
    double scale_x = aspect_ratio * scale_y;
    // The pixel grid through which we shoot rays is at a distance 'focal_length'
    // from the sensor, and is scaled from the canonical [-1,1] in order
    // to produce the target field of view.
    Vector3d grid_origin(scene.camera.position(0)-scale_x, scene.camera.position(1)+scale_y, scene.camera.position(2)-scene.camera.focal_length);
    Vector3d x_displacement(2.0/w*scale_x, 0, 0);
    Vector3d y_displacement(0, -2.0/h*scale_y, 0);
    // Depth of field: 5 lens offsets (note: these offsets are not currently applied to the rays below)
    std::vector<Vector3d> apertureCorrection;
    apertureCorrection.push_back(Vector3d(0, 0, 0));
    apertureCorrection.push_back(Vector3d(scene.camera.lens_radius, 0, 0));
    apertureCorrection.push_back(Vector3d(0, scene.camera.lens_radius, 0));
    apertureCorrection.push_back(Vector3d(-scene.camera.lens_radius, 0, 0));
    apertureCorrection.push_back(Vector3d(0, -scene.camera.lens_radius, 0));
    int depthOfFieldSampleNumber = 1;  // Depth of field turned off for speed

    for (int i = 0; i < w; ++i) {
        std::cout << std::fixed << std::setprecision(2);
        std::cout << "Ray tracing: " << (100.0 * i) / w << "%\r" << std::flush;
        for (int j = 0; j < h; ++j) {
            for (int k = 0; k < depthOfFieldSampleNumber; ++k) {
                Vector3d shift = grid_origin + (i+0.5)*x_displacement + (j+0.5)*y_displacement;
                // Prepare the ray
                Ray ray;
                if (scene.camera.is_perspective) {
                    // Perspective camera
                    ray.origin = scene.camera.position;
                    ray.direction = (shift - ray.origin).normalized();
                } else {
                    // Orthographic camera
                    ray.origin = Vector3d(shift[0], shift[1], scene.camera.position[2]);  // 'shift' already includes the camera's x/y offset
                    ray.direction = Vector3d(0, 0, -1);
                }
                int max_bounce = 3;
                Vector3d C = shoot_ray(scene, ray, max_bounce);
                R(i, j) += C(0) / depthOfFieldSampleNumber;
                G(i, j) += C(1) / depthOfFieldSampleNumber;
                B(i, j) += C(2) / depthOfFieldSampleNumber;
                A(i, j) = 1;
            }
        }
    }
    std::cout << "Ray tracing: 100%  " << std::endl;
    // Save to png
    const std::string filename("raytrace.png");
    write_matrix_to_png(R, G, B, A, filename);
}
////////////////////////////////////////////////////////////////////////////////
Scene load_scene(const std::string &filename) {
    Scene scene;
    // Load json data from scene file
    json data;
    std::ifstream in(filename);
    in >> data;
    // Helper function to read a Vector3d from a json array
    auto read_vec3 = [] (const json &x) {
        return Vector3d(x[0], x[1], x[2]);
    };
    // Read scene info
    scene.background_color = read_vec3(data["Scene"]["Background"]);
    scene.ambient_light = read_vec3(data["Scene"]["Ambient"]);
    // Read camera info
    scene.camera.is_perspective = data["Camera"]["IsPerspective"];
    scene.camera.position = read_vec3(data["Camera"]["Position"]);
    scene.camera.field_of_view = data["Camera"]["FieldOfView"];
    scene.camera.focal_length = data["Camera"]["FocalLength"];
    scene.camera.lens_radius = data["Camera"]["LensRadius"];
    // Read materials
    for (const auto &entry : data["Materials"]) {
        Material mat;
        mat.ambient_color = read_vec3(entry["Ambient"]);
        mat.diffuse_color = read_vec3(entry["Diffuse"]);
        mat.specular_color = read_vec3(entry["Specular"]);
        mat.reflection_color = read_vec3(entry["Mirror"]);
        mat.refraction_color = read_vec3(entry["Refraction"]);
        mat.refraction_index = entry["RefractionIndex"];
        mat.specular_exponent = entry["Shininess"];
        scene.materials.push_back(mat);
    }
    // Read lights
    for (const auto &entry : data["Lights"]) {
        Light light;
        light.position = read_vec3(entry["Position"]);
        light.intensity = read_vec3(entry["Color"]);
        scene.lights.push_back(light);
    }
    // Read objects
    for (const auto &entry : data["Objects"]) {
        ObjectPtr object;
        if (entry["Type"] == "Sphere") {
            auto sphere = std::make_shared<Sphere>();
            sphere->position = read_vec3(entry["Position"]);
            sphere->radius = entry["Radius"];
            object = sphere;
        } else if (entry["Type"] == "Parallelogram") {
            auto parallelogram = std::make_shared<Parallelogram>();
            parallelogram->origin = read_vec3(entry["Origin"]);
            parallelogram->u = read_vec3(entry["U"]);
            parallelogram->v = read_vec3(entry["V"]);
            object = parallelogram;
        } else if (entry["Type"] == "Mesh") {
            // Load mesh from a file
            std::string filename = std::string("../data/") + entry["Path"].get<std::string>();
            object = std::make_shared<Mesh>(filename);
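        } else {
            std::cerr << "Unknown object type: " << entry["Type"] << std::endl;
            continue;  // skip it instead of dereferencing a null pointer below
        }
        object->material = scene.materials[entry["Material"]];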
        scene.objects.push_back(object);
    }
    return scene;
}
////////////////////////////////////////////////////////////////////////////////
int main(int argc, char *argv[]) {
    if (argc < 2) {
        std::cerr << "Usage: " << argv[0] << " scene.json" << std::endl;
        return 1;
    }
    Scene scene = load_scene(argv[1]);
    render_scene(scene);
    return 0;
}
