Making More Use of the Option Enum
In Rust, it is common to use the Option<T>
enum to express a return type that either has an unwrappable value or does not have a value at all.
Let’s revisit the Hittable
trait:
/// Pre-refactor form of the trait: an object a ray can be tested against.
pub trait Hittable: Send + Sync {
/// Tests `ray` against the object for the parameter bounds `t_min`/`t_max`.
///
/// Reports success through the `bool` return value and passes the hit
/// details back through the `rec` out-parameter.
fn hit(&self, ray: &Ray, t_min: f64, t_max: f64, rec: &mut HitRecord) -> bool;
}
We are mutating HitRecord only when we hit something, which makes this a perfect place to use the Option<T> enum.
Let’s refactor our code to use the Option<T>
enum:
use std::sync::Arc;
use crate::material::Material;
use crate::ray::Ray;
use crate::vec3::{self, Point3, Vec3};
/// Details of a single ray/surface intersection.
///
/// `mat` is still wrapped in `Option` in this version so that the whole
/// record can be produced through `Default`.
#[derive(Clone, Default)]
pub struct HitRecord {
    /// Point at which the ray meets the surface.
    pub p: Point3,
    /// Surface normal, oriented by [`HitRecord::set_face_normal`].
    pub normal: Vec3,
    /// Material of the object that was hit, if one has been set.
    pub mat: Option<Arc<dyn Material>>,
    /// Ray parameter at the intersection.
    pub t: f64,
    /// `true` when the ray direction opposes the outward normal.
    pub front_face: bool,
}

impl HitRecord {
    /// Builds a record with every field at its default value.
    pub fn new() -> HitRecord {
        HitRecord::default()
    }

    /// Records which side of the surface `r` struck and stores a normal
    /// that always points against the ray.
    ///
    /// `outward_normal` is the surface normal pointing out of the object.
    pub fn set_face_normal(&mut self, r: &Ray, outward_normal: Vec3) {
        let hits_outside = vec3::dot(r.direction(), outward_normal) < 0.0;
        let oriented = if hits_outside {
            outward_normal
        } else {
            // The ray is inside the object: flip the normal to face it.
            -outward_normal
        };

        self.front_face = hits_outside;
        self.normal = oriented;
    }
}
/// An object a ray can be tested against.
pub trait Hittable: Send + Sync {
/// Tests `ray` against the object for the parameter bounds `t_min`/`t_max`.
///
/// Returns `Some(HitRecord)` describing the hit, or `None` when there is
/// no hit to report.
fn hit(&self, ray: &Ray, t_min: f64, t_max: f64) -> Option<HitRecord>;
}
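Since Hittable::hit() now returns either Some(HitRecord) or None, a HitRecord only ever exists for an actual hit.
We can therefore remove the Option wrapper from its mat field and require all fields of the HitRecord struct to be initialized when it is created.
Because Arc<dyn Material> has no default value, Default can no longer be derived, so we also remove the derive attribute and the new() function: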
use std::sync::Arc;
use crate::material::Material;
use crate::ray::Ray;
use crate::vec3::{self, Point3, Vec3};
/// Details of a single ray/surface intersection.
///
/// Every field is mandatory; a record is only constructed for an actual hit.
pub struct HitRecord {
    /// Point at which the ray meets the surface.
    pub p: Point3,
    /// Surface normal, oriented by [`HitRecord::set_face_normal`].
    pub normal: Vec3,
    /// Material of the object that was hit.
    pub mat: Arc<dyn Material>,
    /// Ray parameter at the intersection.
    pub t: f64,
    /// `true` when the ray direction opposes the outward normal.
    pub front_face: bool,
}

impl HitRecord {
    /// Records which side of the surface `r` struck and stores a normal
    /// that always points against the ray.
    ///
    /// `outward_normal` is the surface normal pointing out of the object.
    pub fn set_face_normal(&mut self, r: &Ray, outward_normal: Vec3) {
        let hits_outside = vec3::dot(r.direction(), outward_normal) < 0.0;
        let oriented = if hits_outside {
            outward_normal
        } else {
            // The ray is inside the object: flip the normal to face it.
            -outward_normal
        };

        self.front_face = hits_outside;
        self.normal = oriented;
    }
}
/// An object a ray can be tested against.
pub trait Hittable: Send + Sync {
/// Tests `ray` against the object for the parameter bounds `t_min`/`t_max`.
///
/// Returns `Some(HitRecord)` describing the hit, or `None` when there is
/// no hit to report.
fn hit(&self, ray: &Ray, t_min: f64, t_max: f64) -> Option<HitRecord>;
}
We refactor the HittableList
implementation using the if let
expression:
use crate::hittable::{HitRecord, Hittable};
use crate::ray::Ray;
/// A collection of boxed [`Hittable`] objects.
#[derive(Default)]
pub struct HittableList {
    objects: Vec<Box<dyn Hittable>>,
}

impl HittableList {
    /// Creates an empty list.
    pub fn new() -> HittableList {
        HittableList {
            objects: Vec::new(),
        }
    }

    /// Appends `object` to the end of the list.
    pub fn add(&mut self, object: Box<dyn Hittable>) {
        self.objects.push(object);
    }
}
impl Hittable for HittableList {
    /// Returns the hit with the smallest `t` among all objects in the list,
    /// or `None` when no object is hit.
    fn hit(&self, ray: &Ray, t_min: f64, t_max: f64) -> Option<HitRecord> {
        self.objects
            .iter()
            .fold(None, |nearest: Option<HitRecord>, object| {
                // Shrink the upper bound to the best hit so far, so later
                // objects can only replace it with a closer one.
                let upper_bound = nearest.as_ref().map_or(t_max, |rec| rec.t);
                object.hit(ray, t_min, upper_bound).or(nearest)
            })
    }
}
And modify the Sphere
implementation:
use std::sync::Arc;
use crate::hittable::{HitRecord, Hittable};
use crate::material::Material;
use crate::ray::Ray;
use crate::vec3::{self, Point3};
/// A sphere defined by its center, its radius and the material it is made of.
pub struct Sphere {
    center: Point3,
    radius: f64,
    mat: Arc<dyn Material>,
}

impl Sphere {
    /// Creates a sphere centered at `cen` with radius `r` and material `m`.
    pub fn new(cen: Point3, r: f64, m: Arc<dyn Material>) -> Sphere {
        let (center, radius, mat) = (cen, r, m);
        Sphere {
            center,
            radius,
            mat,
        }
    }
}
impl Hittable for Sphere {
fn hit(&self, r: &Ray, t_min: f64, t_max: f64) -> Option<HitRecord> {
let oc = r.origin() - self.center;
let a = r.direction().length_squared();
let half_b = vec3::dot(oc, r.direction());
let c = oc.length_squared() - self.radius * self.radius;
let discriminant = half_b * half_b - a * c;
if discriminant < 0.0 {
return None;
}
let sqrt_d = f64::sqrt(discriminant);
// Find the nearest root that lies in the acceptable range
let mut root = (-half_b - sqrt_d) / a;
if root <= t_min || t_max <= root {
root = (-half_b + sqrt_d) / a;
if root <= t_min || t_max <= root {
return None;
}
}
let mut rec = HitRecord {
t: root,
p: r.at(root),
mat: self.mat.clone(),
normal: Default::default(),
front_face: Default::default(),
};
let outward_normal = (rec.p - self.center) / self.radius;
rec.set_face_normal(r, outward_normal);
Some(rec)
}
}
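Similar to the Hittable trait, we can refactor the scatter() function to return Option<ScatterRecord>.
First we add the ScatterRecord struct and change the signature in the Material trait; the implementations in this listing still use the old signature and are updated in the next step: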
use crate::color::Color;
use crate::hittable::HitRecord;
use crate::ray::Ray;
use crate::{common, vec3};
/// Output of a successful `Material::scatter` call.
pub struct ScatterRecord {
/// Color factor that `ray_color` multiplies with the color gathered along `scattered`.
pub attenuation: Color,
/// The outgoing ray produced by the material.
pub scattered: Ray,
}
/// A surface material that decides how an incoming ray is scattered.
pub trait Material: Send + Sync {
/// Scatters the incoming ray `r_in` at the intersection described by `rec`.
///
/// Returns `Some(ScatterRecord)` with the attenuation and the outgoing ray,
/// or `None` when no scattered ray is produced.
fn scatter(&self, r_in: &Ray, rec: &HitRecord) -> Option<ScatterRecord>;
}
/// Diffuse material; `albedo` is the attenuation it applies when scattering.
pub struct Lambertian {
    albedo: Color,
}

impl Lambertian {
    /// Creates a diffuse material with albedo `a`.
    pub fn new(a: Color) -> Lambertian {
        let albedo = a;
        Lambertian { albedo }
    }
}
impl Material for Lambertian {
    // NOTE: this listing still shows the pre-refactor out-parameter
    // signature, which does not match the `Option<ScatterRecord>` signature
    // declared on the trait above.
    /// Scatters around the surface normal; always reports success.
    ///
    /// Writes the albedo to `attenuation` and the outgoing ray to `scattered`.
    fn scatter(
        &self,
        _r_in: &Ray,
        rec: &HitRecord,
        attenuation: &mut Color,
        scattered: &mut Ray,
    ) -> bool {
        let candidate = rec.normal + vec3::random_unit_vector();

        // Catch degenerate scatter direction: fall back to the normal itself.
        let scatter_direction = if candidate.near_zero() {
            rec.normal
        } else {
            candidate
        };

        *attenuation = self.albedo;
        *scattered = Ray::new(rec.p, scatter_direction);
        true
    }
}
/// Reflective material with an albedo and a fuzz factor for blurred reflections.
pub struct Metal {
    albedo: Color,
    fuzz: f64,
}

impl Metal {
    /// Creates a metal with albedo `a` and fuzz `f`; fuzz is capped at `1.0`.
    pub fn new(a: Color, f: f64) -> Metal {
        // `f64::min` yields `f` below 1.0 and 1.0 otherwise (also for NaN).
        let fuzz = f.min(1.0);
        Metal { albedo: a, fuzz }
    }
}
impl Material for Metal {
    // NOTE: this listing still shows the pre-refactor out-parameter
    // signature, which does not match the `Option<ScatterRecord>` signature
    // declared on the trait above.
    /// Reflects the incoming ray about the normal, perturbed by the fuzz factor.
    ///
    /// Writes the albedo to `attenuation` and the outgoing ray to `scattered`;
    /// returns `true` only when the outgoing ray leaves on the normal's side.
    fn scatter(
        &self,
        r_in: &Ray,
        rec: &HitRecord,
        attenuation: &mut Color,
        scattered: &mut Ray,
    ) -> bool {
        let incoming = vec3::unit_vector(r_in.direction());
        let reflected = vec3::reflect(incoming, rec.normal);
        let fuzzed = reflected + self.fuzz * vec3::random_in_unit_sphere();

        *attenuation = self.albedo;
        *scattered = Ray::new(rec.p, fuzzed);

        let leaves_surface = vec3::dot(scattered.direction(), rec.normal) > 0.0;
        leaves_surface
    }
}
/// Transparent material described by its index of refraction.
pub struct Dielectric {
    /// Index of refraction.
    ir: f64,
}

impl Dielectric {
    /// Creates a dielectric with the given index of refraction.
    pub fn new(index_of_refraction: f64) -> Dielectric {
        let ir = index_of_refraction;
        Dielectric { ir }
    }

    /// Schlick's approximation of the reflectance for the given `cosine`
    /// and refractive index ratio `ref_idx`.
    fn reflectance(cosine: f64, ref_idx: f64) -> f64 {
        let ratio = (1.0 - ref_idx) / (1.0 + ref_idx);
        let r0 = ratio * ratio;
        let falloff = (1.0 - cosine).powf(5.0);
        r0 + (1.0 - r0) * falloff
    }
}
impl Material for Dielectric {
    // NOTE: this listing still shows the pre-refactor out-parameter
    // signature, which does not match the `Option<ScatterRecord>` signature
    // declared on the trait above.
    /// Either reflects or refracts the incoming ray; always reports success.
    ///
    /// Writes a white attenuation to `attenuation` and the outgoing ray to
    /// `scattered`.
    fn scatter(
        &self,
        r_in: &Ray,
        rec: &HitRecord,
        attenuation: &mut Color,
        scattered: &mut Ray,
    ) -> bool {
        // The ratio is inverted for front-face hits.
        let refraction_ratio = if rec.front_face {
            1.0 / self.ir
        } else {
            self.ir
        };

        let unit_direction = vec3::unit_vector(r_in.direction());
        let cos_theta = vec3::dot(-unit_direction, rec.normal).min(1.0);
        let sin_theta = (1.0 - cos_theta * cos_theta).sqrt();

        // The `||` short-circuit matters: the random sample is only drawn
        // when refraction is geometrically possible.
        let cannot_refract = refraction_ratio * sin_theta > 1.0;
        let must_reflect = cannot_refract
            || Self::reflectance(cos_theta, refraction_ratio) > common::random_double();

        let direction = if must_reflect {
            vec3::reflect(unit_direction, rec.normal)
        } else {
            vec3::refract(unit_direction, rec.normal, refraction_ratio)
        };

        *attenuation = Color::new(1.0, 1.0, 1.0);
        *scattered = Ray::new(rec.p, direction);
        true
    }
}
Then we update each implementation to match the new signature:
use crate::color::Color;
use crate::hittable::HitRecord;
use crate::ray::Ray;
use crate::{common, vec3};
/// Output of a successful `Material::scatter` call.
pub struct ScatterRecord {
/// Color factor that `ray_color` multiplies with the color gathered along `scattered`.
pub attenuation: Color,
/// The outgoing ray produced by the material.
pub scattered: Ray,
}
/// A surface material that decides how an incoming ray is scattered.
pub trait Material: Send + Sync {
/// Scatters the incoming ray `r_in` at the intersection described by `rec`.
///
/// Returns `Some(ScatterRecord)` with the attenuation and the outgoing ray,
/// or `None` when no scattered ray is produced.
fn scatter(&self, r_in: &Ray, rec: &HitRecord) -> Option<ScatterRecord>;
}
/// Diffuse material; `albedo` is the attenuation it applies when scattering.
pub struct Lambertian {
    albedo: Color,
}

impl Lambertian {
    /// Creates a diffuse material with albedo `a`.
    pub fn new(a: Color) -> Lambertian {
        let albedo = a;
        Lambertian { albedo }
    }
}
impl Material for Lambertian {
    /// Scatters around the surface normal; always returns `Some`.
    fn scatter(&self, _r_in: &Ray, rec: &HitRecord) -> Option<ScatterRecord> {
        let candidate = rec.normal + vec3::random_unit_vector();

        // Catch degenerate scatter direction: fall back to the normal itself.
        let scatter_direction = if candidate.near_zero() {
            rec.normal
        } else {
            candidate
        };

        let outcome = ScatterRecord {
            attenuation: self.albedo,
            scattered: Ray::new(rec.p, scatter_direction),
        };
        Some(outcome)
    }
}
/// Reflective material with an albedo and a fuzz factor for blurred reflections.
pub struct Metal {
    albedo: Color,
    fuzz: f64,
}

impl Metal {
    /// Creates a metal with albedo `a` and fuzz `f`; fuzz is capped at `1.0`.
    pub fn new(a: Color, f: f64) -> Metal {
        // `f64::min` yields `f` below 1.0 and 1.0 otherwise (also for NaN).
        let fuzz = f.min(1.0);
        Metal { albedo: a, fuzz }
    }
}
impl Material for Metal {
    /// Reflects the incoming ray about the normal, perturbed by the fuzz factor.
    ///
    /// Returns `None` when the perturbed ray does not leave on the normal's side.
    fn scatter(&self, r_in: &Ray, rec: &HitRecord) -> Option<ScatterRecord> {
        let incoming = vec3::unit_vector(r_in.direction());
        let reflected = vec3::reflect(incoming, rec.normal);
        let scattered = Ray::new(rec.p, reflected + self.fuzz * vec3::random_in_unit_sphere());

        let leaves_surface = vec3::dot(scattered.direction(), rec.normal) > 0.0;
        leaves_surface.then(|| ScatterRecord {
            attenuation: self.albedo,
            scattered,
        })
    }
}
/// Transparent material described by its index of refraction.
pub struct Dielectric {
    /// Index of refraction.
    ir: f64,
}

impl Dielectric {
    /// Creates a dielectric with the given index of refraction.
    pub fn new(index_of_refraction: f64) -> Dielectric {
        let ir = index_of_refraction;
        Dielectric { ir }
    }

    /// Schlick's approximation of the reflectance for the given `cosine`
    /// and refractive index ratio `ref_idx`.
    fn reflectance(cosine: f64, ref_idx: f64) -> f64 {
        let ratio = (1.0 - ref_idx) / (1.0 + ref_idx);
        let r0 = ratio * ratio;
        let falloff = (1.0 - cosine).powf(5.0);
        r0 + (1.0 - r0) * falloff
    }
}
impl Material for Dielectric {
    /// Either reflects or refracts the incoming ray; always returns `Some`
    /// with a white attenuation.
    fn scatter(&self, r_in: &Ray, rec: &HitRecord) -> Option<ScatterRecord> {
        // The ratio is inverted for front-face hits.
        let refraction_ratio = if rec.front_face {
            1.0 / self.ir
        } else {
            self.ir
        };

        let unit_direction = vec3::unit_vector(r_in.direction());
        let cos_theta = vec3::dot(-unit_direction, rec.normal).min(1.0);
        let sin_theta = (1.0 - cos_theta * cos_theta).sqrt();

        // The `||` short-circuit matters: the random sample is only drawn
        // when refraction is geometrically possible.
        let cannot_refract = refraction_ratio * sin_theta > 1.0;
        let must_reflect = cannot_refract
            || Self::reflectance(cos_theta, refraction_ratio) > common::random_double();

        let direction = if must_reflect {
            vec3::reflect(unit_direction, rec.normal)
        } else {
            vec3::refract(unit_direction, rec.normal, refraction_ratio)
        };

        let outcome = ScatterRecord {
            attenuation: Color::new(1.0, 1.0, 1.0),
            scattered: Ray::new(rec.p, direction),
        };
        Some(outcome)
    }
}
Finally, we refactor the main()
function:
mod camera;
mod color;
mod common;
mod hittable;
mod hittable_list;
mod material;
mod ray;
mod sphere;
mod vec3;
use std::io;
use std::sync::Arc;
use rayon::prelude::*;
use camera::Camera;
use color::Color;
use hittable::Hittable;
use hittable_list::HittableList;
use material::{Dielectric, Lambertian, Metal};
use ray::Ray;
use sphere::Sphere;
use vec3::Point3;
/// Computes the color seen along ray `r` in `world`, following at most
/// `depth` bounces.
///
/// A hit whose material scatters contributes the attenuation times the color
/// of the scattered ray; a hit that does not scatter is black; a miss yields
/// the white-to-blue background gradient.
fn ray_color(r: &Ray, world: &dyn Hittable, depth: i32) -> Color {
    // If we've exceeded the ray bounce limit, no more light is gathered
    if depth <= 0 {
        return Color::new(0.0, 0.0, 0.0);
    }

    // The lower bound 0.001 skips hits at (almost) zero distance.
    match world.hit(r, 0.001, common::INFINITY) {
        Some(hit_rec) => match hit_rec.mat.scatter(r, &hit_rec) {
            Some(scatter_rec) => {
                scatter_rec.attenuation * ray_color(&scatter_rec.scattered, world, depth - 1)
            }
            None => Color::new(0.0, 0.0, 0.0),
        },
        None => {
            // Background: blend white and light blue by the ray's height.
            let unit_direction = vec3::unit_vector(r.direction());
            let t = 0.5 * (unit_direction.y() + 1.0);
            (1.0 - t) * Color::new(1.0, 1.0, 1.0) + t * Color::new(0.5, 0.7, 1.0)
        }
    }
}
/// Builds the demo scene: a huge ground sphere, a grid of small randomly
/// placed spheres with random materials, and three large feature spheres.
fn random_scene() -> HittableList {
    let mut world = HittableList::new();

    // Ground
    world.add(Box::new(Sphere::new(
        Point3::new(0.0, -1000.0, 0.0),
        1000.0,
        Arc::new(Lambertian::new(Color::new(0.5, 0.5, 0.5))),
    )));

    for a in -11..11 {
        for b in -11..11 {
            // The order of the random draws below is kept fixed: material
            // choice first, then the x offset, then the z offset.
            let choose_mat = common::random_double();
            let x = a as f64 + 0.9 * common::random_double();
            let z = b as f64 + 0.9 * common::random_double();
            let center = Point3::new(x, 0.2, z);

            // Only place a sphere when its center is more than 0.9 away
            // from the point (4, 0.2, 0).
            let far_enough = (center - Point3::new(4.0, 0.2, 0.0)).length() > 0.9;
            if !far_enough {
                continue;
            }

            if choose_mat < 0.8 {
                // Diffuse
                let albedo = Color::random() * Color::random();
                world.add(Box::new(Sphere::new(
                    center,
                    0.2,
                    Arc::new(Lambertian::new(albedo)),
                )));
            } else if choose_mat < 0.95 {
                // Metal
                let albedo = Color::random_range(0.5, 1.0);
                let fuzz = common::random_double_range(0.0, 0.5);
                world.add(Box::new(Sphere::new(
                    center,
                    0.2,
                    Arc::new(Metal::new(albedo, fuzz)),
                )));
            } else {
                // Glass
                world.add(Box::new(Sphere::new(
                    center,
                    0.2,
                    Arc::new(Dielectric::new(1.5)),
                )));
            }
        }
    }

    // Three large feature spheres: glass, diffuse and metal.
    world.add(Box::new(Sphere::new(
        Point3::new(0.0, 1.0, 0.0),
        1.0,
        Arc::new(Dielectric::new(1.5)),
    )));
    world.add(Box::new(Sphere::new(
        Point3::new(-4.0, 1.0, 0.0),
        1.0,
        Arc::new(Lambertian::new(Color::new(0.4, 0.2, 0.1))),
    )));
    world.add(Box::new(Sphere::new(
        Point3::new(4.0, 1.0, 0.0),
        1.0,
        Arc::new(Metal::new(Color::new(0.7, 0.6, 0.5), 0.0)),
    )));

    world
}
/// Renders the random scene and writes it to stdout as a plain-text PPM
/// (`P3`) image, reporting progress on stderr.
fn main() {
    // Image
    const ASPECT_RATIO: f64 = 3.0 / 2.0;
    const IMAGE_WIDTH: i32 = 1200;
    const IMAGE_HEIGHT: i32 = (IMAGE_WIDTH as f64 / ASPECT_RATIO) as i32;
    const SAMPLES_PER_PIXEL: i32 = 500;
    const MAX_DEPTH: i32 = 50;

    // World
    let world = random_scene();

    // Camera
    let cam = Camera::new(
        Point3::new(13.0, 2.0, 3.0), // lookfrom
        Point3::new(0.0, 0.0, 0.0),  // lookat
        Point3::new(0.0, 1.0, 0.0),  // vup
        20.0,
        ASPECT_RATIO,
        0.1,  // aperture
        10.0, // dist_to_focus
    );

    // Render
    print!("P3\n{} {}\n255\n", IMAGE_WIDTH, IMAGE_HEIGHT);

    let u_span = (IMAGE_WIDTH - 1) as f64;
    let v_span = (IMAGE_HEIGHT - 1) as f64;

    // Rows are emitted top to bottom; the pixels of one row are computed in
    // parallel and then written out in order.
    for row in (0..IMAGE_HEIGHT).rev() {
        eprint!("\rScanlines remaining: {} ", row);

        let render_pixel = |col: i32| {
            let mut accumulated = Color::new(0.0, 0.0, 0.0);
            for _ in 0..SAMPLES_PER_PIXEL {
                let u = ((col as f64) + common::random_double()) / u_span;
                let v = ((row as f64) + common::random_double()) / v_span;
                let sample_ray = cam.get_ray(u, v);
                accumulated += ray_color(&sample_ray, &world, MAX_DEPTH);
            }
            accumulated
        };

        let scanline: Vec<_> = (0..IMAGE_WIDTH)
            .into_par_iter()
            .map(render_pixel)
            .collect();

        for pixel_color in scanline {
            color::write_color(&mut io::stdout(), pixel_color, SAMPLES_PER_PIXEL);
        }
    }

    eprint!("\nDone.\n");
}
The code now looks much nicer!