Fearless Parallelism with Rayon
One of Rust’s greatest strengths is that safe Rust code that compiles is guaranteed to be free of data races, so we can parallelize our code with confidence.
We’ll use the rayon crate, a data-parallelism library, to parallelize the code and utilize all CPU cores of our machine.
Let’s install the crate:
cargo add rayon
We modify the inner loop, which generates the pixels of each horizontal scanline:
mod camera;
mod color;
mod common;
mod hittable;
mod hittable_list;
mod material;
mod ray;
mod sphere;
mod vec3;
use std::io;
use std::rc::Rc;
use rayon::prelude::*;
use camera::Camera;
use color::Color;
use hittable::{HitRecord, Hittable};
use hittable_list::HittableList;
use material::{Dielectric, Lambertian, Metal};
use ray::Ray;
use sphere::Sphere;
use vec3::Point3;
fn ray_color(r: &Ray, world: &dyn Hittable, depth: i32) -> Color {
    // If we've exceeded the ray bounce limit, no more light is gathered
    if depth <= 0 {
        return Color::new(0.0, 0.0, 0.0);
    }
    let mut rec = HitRecord::new();
    if world.hit(r, 0.001, common::INFINITY, &mut rec) {
        let mut attenuation = Color::default();
        let mut scattered = Ray::default();
        if rec
            .mat
            .as_ref()
            .unwrap()
            .scatter(r, &rec, &mut attenuation, &mut scattered)
        {
            return attenuation * ray_color(&scattered, world, depth - 1);
        }
        return Color::new(0.0, 0.0, 0.0);
    }
    let unit_direction = vec3::unit_vector(r.direction());
    let t = 0.5 * (unit_direction.y() + 1.0);
    (1.0 - t) * Color::new(1.0, 1.0, 1.0) + t * Color::new(0.5, 0.7, 1.0)
}
fn random_scene() -> HittableList {
    let mut world = HittableList::new();
    let ground_material = Rc::new(Lambertian::new(Color::new(0.5, 0.5, 0.5)));
    world.add(Box::new(Sphere::new(
        Point3::new(0.0, -1000.0, 0.0),
        1000.0,
        ground_material,
    )));
    for a in -11..11 {
        for b in -11..11 {
            let choose_mat = common::random_double();
            let center = Point3::new(
                a as f64 + 0.9 * common::random_double(),
                0.2,
                b as f64 + 0.9 * common::random_double(),
            );
            if (center - Point3::new(4.0, 0.2, 0.0)).length() > 0.9 {
                if choose_mat < 0.8 {
                    // Diffuse
                    let albedo = Color::random() * Color::random();
                    let sphere_material = Rc::new(Lambertian::new(albedo));
                    world.add(Box::new(Sphere::new(center, 0.2, sphere_material)));
                } else if choose_mat < 0.95 {
                    // Metal
                    let albedo = Color::random_range(0.5, 1.0);
                    let fuzz = common::random_double_range(0.0, 0.5);
                    let sphere_material = Rc::new(Metal::new(albedo, fuzz));
                    world.add(Box::new(Sphere::new(center, 0.2, sphere_material)));
                } else {
                    // Glass
                    let sphere_material = Rc::new(Dielectric::new(1.5));
                    world.add(Box::new(Sphere::new(center, 0.2, sphere_material)));
                }
            }
        }
    }
    let material1 = Rc::new(Dielectric::new(1.5));
    world.add(Box::new(Sphere::new(
        Point3::new(0.0, 1.0, 0.0),
        1.0,
        material1,
    )));
    let material2 = Rc::new(Lambertian::new(Color::new(0.4, 0.2, 0.1)));
    world.add(Box::new(Sphere::new(
        Point3::new(-4.0, 1.0, 0.0),
        1.0,
        material2,
    )));
    let material3 = Rc::new(Metal::new(Color::new(0.7, 0.6, 0.5), 0.0));
    world.add(Box::new(Sphere::new(
        Point3::new(4.0, 1.0, 0.0),
        1.0,
        material3,
    )));
    world
}
fn main() {
    // Image
    const ASPECT_RATIO: f64 = 3.0 / 2.0;
    const IMAGE_WIDTH: i32 = 1200;
    const IMAGE_HEIGHT: i32 = (IMAGE_WIDTH as f64 / ASPECT_RATIO) as i32;
    const SAMPLES_PER_PIXEL: i32 = 500;
    const MAX_DEPTH: i32 = 50;
    // World
    let world = random_scene();
    // Camera
    let lookfrom = Point3::new(13.0, 2.0, 3.0);
    let lookat = Point3::new(0.0, 0.0, 0.0);
    let vup = Point3::new(0.0, 1.0, 0.0);
    let dist_to_focus = 10.0;
    let aperture = 0.1;
    let cam = Camera::new(
        lookfrom,
        lookat,
        vup,
        20.0,
        ASPECT_RATIO,
        aperture,
        dist_to_focus,
    );
    // Render
    print!("P3\n{} {}\n255\n", IMAGE_WIDTH, IMAGE_HEIGHT);
    for j in (0..IMAGE_HEIGHT).rev() {
        eprint!("\rScanlines remaining: {} ", j);
        let pixel_colors: Vec<_> = (0..IMAGE_WIDTH)
            .into_par_iter()
            .map(|i| {
                let mut pixel_color = Color::new(0.0, 0.0, 0.0);
                for _ in 0..SAMPLES_PER_PIXEL {
                    let u = ((i as f64) + common::random_double()) / (IMAGE_WIDTH - 1) as f64;
                    let v = ((j as f64) + common::random_double()) / (IMAGE_HEIGHT - 1) as f64;
                    let r = cam.get_ray(u, v);
                    pixel_color += ray_color(&r, &world, MAX_DEPTH);
                }
                pixel_color
            })
            .collect();
        for pixel_color in pixel_colors {
            color::write_color(&mut io::stdout(), pixel_color, SAMPLES_PER_PIXEL);
        }
    }
    eprint!("\nDone.\n");
}
Here’s what the code above does for each scanline:
- Use into_par_iter() to convert the Range into a ParallelIterator.
- Use map() with a closure specifying how the result of each pixel is calculated when iterating i. It will return another type implementing ParallelIterator.
- Use collect() to iterate, execute, and gather the results into a Vec<Color>.
- Write each pixel out in the Vec<Color>.
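To see the shape of these steps without the ray tracer, here is a minimal sketch that uses only the standard library. It is not how rayon works internally (rayon schedules work on a work-stealing thread pool); it only illustrates the "map each index on some thread, then gather the results in order" pattern. The helper name parallel_map_range and its fixed chunking strategy are assumptions made for this illustration.

```rust
use std::thread;

/// Hypothetical helper: apply `f` to every index in `0..len` on up to
/// `workers` scoped threads and return the results in index order.
fn parallel_map_range<T, F>(len: usize, workers: usize, f: F) -> Vec<T>
where
    T: Send,
    F: Fn(usize) -> T + Sync,
{
    let workers = workers.max(1);
    // Ceiling division so every index lands in exactly one chunk;
    // the `max(1)` avoids a zero step when `len` is 0.
    let chunk = ((len + workers - 1) / workers).max(1);
    // Share the closure by reference; `&F` is `Send` because `F: Sync`.
    let f = &f;
    thread::scope(|s| {
        // Spawn one scoped thread per chunk of indices.
        let handles: Vec<_> = (0..len)
            .step_by(chunk)
            .map(|start| {
                let end = (start + chunk).min(len);
                s.spawn(move || (start..end).map(f).collect::<Vec<T>>())
            })
            .collect();
        // Joining in spawn order keeps the output in index order.
        handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker thread panicked"))
            .collect()
    })
}

fn main() {
    // Each "pixel" here is just the square of its index.
    let squares = parallel_map_range(10, 4, |i| i * i);
    println!("{:?}", squares);
}
```

Note the bounds on the closure: it must be Sync because several threads call it at once. The same kind of requirement is what the compiler enforces on the closure passed to rayon's map().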
It is worth noting that the collect() function is generic over its return type. This means the caller gets to choose which collection to build (e.g., Vec, LinkedList, etc.). That’s why we have to explicitly annotate the type Vec<_> in the code, since the compiler cannot infer the target collection on its own. Alternatively, we can use a turbofish, collect::<Vec<_>>(), to tell the compiler which type we want.
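The two ways of choosing the target collection can be tried in isolation. The sketch below uses the standard library's sequential Iterator::collect(), which has the same generic-return-type shape as the parallel version; the function names doubled_vec and doubled_list are made up for this example.

```rust
use std::collections::LinkedList;

/// Chooses the collection with a type annotation on the binding.
fn doubled_vec(n: i32) -> Vec<i32> {
    let values: Vec<_> = (0..n).map(|i| i * 2).collect();
    values
}

/// Chooses the collection with a turbofish at the call site.
fn doubled_list(n: i32) -> LinkedList<i32> {
    (0..n).map(|i| i * 2).collect::<LinkedList<_>>()
}

fn main() {
    println!("{:?}", doubled_vec(5));
    println!("{:?}", doubled_list(5));
}
```

In both functions the element type is left as _ and inferred from the iterator; only the container has to be named.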
And that’s it. We’ve finished modifying the code with only a few lines of change!
However, it does not actually work yet. If you run cargo check, the compiler will show:
error[E0277]: `(dyn Hittable + 'static)` cannot be shared between threads safely
The reason is that we have not made our code thread-safe.
One way to fix the error above is to replace every occurrence of dyn Hittable with dyn Hittable + Sync, telling the compiler that we will only pass Hittable trait objects that also implement the Sync trait. However, this makes the code repetitive.
Another way to fix the error is to specify “supertraits” for the Hittable trait, which ensures that every type implementing the Hittable trait also implements the supertraits.
For full thread safety, let’s make both Send (safe to be sent between threads) and Sync (safe to be shared between threads) supertraits of the Hittable trait:
use std::rc::Rc;
use crate::material::Material;
use crate::ray::Ray;
use crate::vec3::{self, Point3, Vec3};
#[derive(Clone, Default)]
pub struct HitRecord {
    pub p: Point3,
    pub normal: Vec3,
    pub mat: Option<Rc<dyn Material>>,
    pub t: f64,
    pub front_face: bool,
}
impl HitRecord {
    pub fn new() -> HitRecord {
        Default::default()
    }
    pub fn set_face_normal(&mut self, r: &Ray, outward_normal: Vec3) {
        self.front_face = vec3::dot(r.direction(), outward_normal) < 0.0;
        self.normal = if self.front_face {
            outward_normal
        } else {
            -outward_normal
        };
    }
}
pub trait Hittable: Send + Sync {
    fn hit(&self, ray: &Ray, t_min: f64, t_max: f64, rec: &mut HitRecord) -> bool;
}
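The effect of the Send + Sync supertraits can be checked in a small standalone sketch. Everything here (the Shape trait, the Square type, and total_area) is hypothetical and only mirrors the structure of Hittable; it uses scoped threads from the standard library rather than rayon.

```rust
use std::thread;

// Hypothetical trait for illustration; the supertraits mirror `Hittable`.
trait Shape: Send + Sync {
    fn area(&self) -> f64;
}

struct Square(f64);

impl Shape for Square {
    fn area(&self) -> f64 {
        self.0 * self.0
    }
}

/// Sum the areas, computing each one on its own scoped thread.
fn total_area(shapes: &[Box<dyn Shape>]) -> f64 {
    thread::scope(|s| {
        // Each thread borrows one trait object; this is only allowed
        // because `dyn Shape` is known to be `Sync`.
        let handles: Vec<_> = shapes
            .iter()
            .map(|shape| s.spawn(move || shape.area()))
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("worker thread panicked"))
            .sum::<f64>()
    })
}

fn main() {
    let shapes: Vec<Box<dyn Shape>> = vec![Box::new(Square(1.0)), Box::new(Square(2.0))];
    println!("{}", total_area(&shapes));
}
```

If the supertraits are removed from Shape, the spawn call should be rejected with an E0277 error much like the one shown earlier, because a plain dyn Shape makes no promise about thread safety.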
We add the same Send + Sync supertraits to the Material trait:
use crate::color::Color;
use crate::hittable::HitRecord;
use crate::ray::Ray;
use crate::{common, vec3};
pub trait Material: Send + Sync {
    fn scatter(
        &self,
        r_in: &Ray,
        rec: &HitRecord,
        attenuation: &mut Color,
        scattered: &mut Ray,
    ) -> bool;
}
pub struct Lambertian {
    albedo: Color,
}
impl Lambertian {
    pub fn new(a: Color) -> Lambertian {
        Lambertian { albedo: a }
    }
}
impl Material for Lambertian {
    fn scatter(
        &self,
        _r_in: &Ray,
        rec: &HitRecord,
        attenuation: &mut Color,
        scattered: &mut Ray,
    ) -> bool {
        let mut scatter_direction = rec.normal + vec3::random_unit_vector();
        // Catch degenerate scatter direction
        if scatter_direction.near_zero() {
            scatter_direction = rec.normal;
        }
        *attenuation = self.albedo;
        *scattered = Ray::new(rec.p, scatter_direction);
        true
    }
}
pub struct Metal {
    albedo: Color,
    fuzz: f64,
}
impl Metal {
    pub fn new(a: Color, f: f64) -> Metal {
        Metal {
            albedo: a,
            fuzz: if f < 1.0 { f } else { 1.0 },
        }
    }
}
impl Material for Metal {
    fn scatter(
        &self,
        r_in: &Ray,
        rec: &HitRecord,
        attenuation: &mut Color,
        scattered: &mut Ray,
    ) -> bool {
        let reflected = vec3::reflect(vec3::unit_vector(r_in.direction()), rec.normal);
        *attenuation = self.albedo;
        *scattered = Ray::new(rec.p, reflected + self.fuzz * vec3::random_in_unit_sphere());
        vec3::dot(scattered.direction(), rec.normal) > 0.0
    }
}
pub struct Dielectric {
    ir: f64, // Index of refraction
}
impl Dielectric {
    pub fn new(index_of_refraction: f64) -> Dielectric {
        Dielectric {
            ir: index_of_refraction,
        }
    }
    fn reflectance(cosine: f64, ref_idx: f64) -> f64 {
        // Use Schlick's approximation for reflectance
        let mut r0 = (1.0 - ref_idx) / (1.0 + ref_idx);
        r0 = r0 * r0;
        r0 + (1.0 - r0) * f64::powf(1.0 - cosine, 5.0)
    }
}
impl Material for Dielectric {
    fn scatter(
        &self,
        r_in: &Ray,
        rec: &HitRecord,
        attenuation: &mut Color,
        scattered: &mut Ray,
    ) -> bool {
        let refraction_ratio = if rec.front_face {
            1.0 / self.ir
        } else {
            self.ir
        };
        let unit_direction = vec3::unit_vector(r_in.direction());
        let cos_theta = f64::min(vec3::dot(-unit_direction, rec.normal), 1.0);
        let sin_theta = f64::sqrt(1.0 - cos_theta * cos_theta);
        let cannot_refract = refraction_ratio * sin_theta > 1.0;
        let direction = if cannot_refract
            || Self::reflectance(cos_theta, refraction_ratio) > common::random_double()
        {
            vec3::reflect(unit_direction, rec.normal)
        } else {
            vec3::refract(unit_direction, rec.normal, refraction_ratio)
        };
        *attenuation = Color::new(1.0, 1.0, 1.0);
        *scattered = Ray::new(rec.p, direction);
        true
    }
}
Now, if we run cargo check, we’ll get similar errors, but for the Rc type:
error[E0277]: `Rc<(dyn Material + 'static)>` cannot be shared between threads safely
error[E0277]: `Rc<(dyn Material + 'static)>` cannot be sent between threads safely
Most types in Rust, including the primitive types, are both Send and Sync, but Rc is not one of them: its reference count is updated without atomic operations. To fix the errors, we can replace Rc with Arc, which uses an atomic reference count. Arc<T> is both Send and Sync whenever T is, and thus can be shared and sent between threads safely.
Send and Sync are marker traits built into Rust’s type system. Any type composed entirely of Send types is automatically Send. The same applies to the Sync trait.
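Both points, the automatic marker traits and Arc being shareable, can be verified with a short standalone sketch. The Pixel struct, the assert_send_sync helper, and shared_sum are hypothetical names introduced only for this example.

```rust
use std::sync::Arc;
use std::thread;

// Compiles only for types that are both `Send` and `Sync`.
fn assert_send_sync<T: Send + Sync>() {}

// Hypothetical struct: every field is `Send + Sync`, so the compiler
// implements both marker traits for it automatically.
struct Pixel {
    rgb: [f64; 3],
}

/// Sums the channels on `workers` threads that all share one `Arc<Pixel>`,
/// then adds up the per-thread results.
fn shared_sum(workers: usize) -> f64 {
    let pixel = Arc::new(Pixel { rgb: [0.25, 0.5, 0.25] });
    let handles: Vec<_> = (0..workers)
        .map(|_| {
            // Cloning an `Arc` bumps an atomic reference count.
            let p = Arc::clone(&pixel);
            thread::spawn(move || p.rgb.iter().sum::<f64>())
        })
        .collect();
    handles
        .into_iter()
        .map(|h| h.join().expect("worker thread panicked"))
        .sum::<f64>()
}

fn main() {
    assert_send_sync::<Pixel>();
    assert_send_sync::<Arc<Pixel>>();
    // Uncommenting the next line should fail to compile, since `Rc` is
    // neither `Send` nor `Sync`:
    // assert_send_sync::<std::rc::Rc<Pixel>>();
    println!("{}", shared_sum(4));
}
```

The assert_send_sync calls do nothing at run time; they exist only so that the compiler checks the trait bounds.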
Let’s replace Rc with Arc in all modules:
// sphere module
use std::sync::Arc;
use crate::hittable::{HitRecord, Hittable};
use crate::material::Material;
use crate::ray::Ray;
use crate::vec3::{self, Point3};
pub struct Sphere {
    center: Point3,
    radius: f64,
    mat: Arc<dyn Material>,
}
impl Sphere {
    pub fn new(cen: Point3, r: f64, m: Arc<dyn Material>) -> Sphere {
        Sphere {
            center: cen,
            radius: r,
            mat: m,
        }
    }
}
impl Hittable for Sphere {
    fn hit(&self, r: &Ray, t_min: f64, t_max: f64, rec: &mut HitRecord) -> bool {
        let oc = r.origin() - self.center;
        let a = r.direction().length_squared();
        let half_b = vec3::dot(oc, r.direction());
        let c = oc.length_squared() - self.radius * self.radius;
        let discriminant = half_b * half_b - a * c;
        if discriminant < 0.0 {
            return false;
        }
        let sqrt_d = f64::sqrt(discriminant);
        // Find the nearest root that lies in the acceptable range
        let mut root = (-half_b - sqrt_d) / a;
        if root <= t_min || t_max <= root {
            root = (-half_b + sqrt_d) / a;
            if root <= t_min || t_max <= root {
                return false;
            }
        }
        rec.t = root;
        rec.p = r.at(rec.t);
        let outward_normal = (rec.p - self.center) / self.radius;
        rec.set_face_normal(r, outward_normal);
        rec.mat = Some(self.mat.clone());
        true
    }
}
// hittable module
use std::sync::Arc;
use crate::material::Material;
use crate::ray::Ray;
use crate::vec3::{self, Point3, Vec3};
#[derive(Clone, Default)]
pub struct HitRecord {
    pub p: Point3,
    pub normal: Vec3,
    pub mat: Option<Arc<dyn Material>>,
    pub t: f64,
    pub front_face: bool,
}
impl HitRecord {
    pub fn new() -> HitRecord {
        Default::default()
    }
    pub fn set_face_normal(&mut self, r: &Ray, outward_normal: Vec3) {
        self.front_face = vec3::dot(r.direction(), outward_normal) < 0.0;
        self.normal = if self.front_face {
            outward_normal
        } else {
            -outward_normal
        };
    }
}
pub trait Hittable: Send + Sync {
    fn hit(&self, ray: &Ray, t_min: f64, t_max: f64, rec: &mut HitRecord) -> bool;
}
// main module (crate root)
mod camera;
mod color;
mod common;
mod hittable;
mod hittable_list;
mod material;
mod ray;
mod sphere;
mod vec3;
use std::io;
use std::sync::Arc;
use rayon::prelude::*;
use camera::Camera;
use color::Color;
use hittable::{HitRecord, Hittable};
use hittable_list::HittableList;
use material::{Dielectric, Lambertian, Metal};
use ray::Ray;
use sphere::Sphere;
use vec3::Point3;
fn ray_color(r: &Ray, world: &dyn Hittable, depth: i32) -> Color {
    // If we've exceeded the ray bounce limit, no more light is gathered
    if depth <= 0 {
        return Color::new(0.0, 0.0, 0.0);
    }
    let mut rec = HitRecord::new();
    if world.hit(r, 0.001, common::INFINITY, &mut rec) {
        let mut attenuation = Color::default();
        let mut scattered = Ray::default();
        if rec
            .mat
            .as_ref()
            .unwrap()
            .scatter(r, &rec, &mut attenuation, &mut scattered)
        {
            return attenuation * ray_color(&scattered, world, depth - 1);
        }
        return Color::new(0.0, 0.0, 0.0);
    }
    let unit_direction = vec3::unit_vector(r.direction());
    let t = 0.5 * (unit_direction.y() + 1.0);
    (1.0 - t) * Color::new(1.0, 1.0, 1.0) + t * Color::new(0.5, 0.7, 1.0)
}
fn random_scene() -> HittableList {
    let mut world = HittableList::new();
    let ground_material = Arc::new(Lambertian::new(Color::new(0.5, 0.5, 0.5)));
    world.add(Box::new(Sphere::new(
        Point3::new(0.0, -1000.0, 0.0),
        1000.0,
        ground_material,
    )));
    for a in -11..11 {
        for b in -11..11 {
            let choose_mat = common::random_double();
            let center = Point3::new(
                a as f64 + 0.9 * common::random_double(),
                0.2,
                b as f64 + 0.9 * common::random_double(),
            );
            if (center - Point3::new(4.0, 0.2, 0.0)).length() > 0.9 {
                if choose_mat < 0.8 {
                    // Diffuse
                    let albedo = Color::random() * Color::random();
                    let sphere_material = Arc::new(Lambertian::new(albedo));
                    world.add(Box::new(Sphere::new(center, 0.2, sphere_material)));
                } else if choose_mat < 0.95 {
                    // Metal
                    let albedo = Color::random_range(0.5, 1.0);
                    let fuzz = common::random_double_range(0.0, 0.5);
                    let sphere_material = Arc::new(Metal::new(albedo, fuzz));
                    world.add(Box::new(Sphere::new(center, 0.2, sphere_material)));
                } else {
                    // Glass
                    let sphere_material = Arc::new(Dielectric::new(1.5));
                    world.add(Box::new(Sphere::new(center, 0.2, sphere_material)));
                }
            }
        }
    }
    let material1 = Arc::new(Dielectric::new(1.5));
    world.add(Box::new(Sphere::new(
        Point3::new(0.0, 1.0, 0.0),
        1.0,
        material1,
    )));
    let material2 = Arc::new(Lambertian::new(Color::new(0.4, 0.2, 0.1)));
    world.add(Box::new(Sphere::new(
        Point3::new(-4.0, 1.0, 0.0),
        1.0,
        material2,
    )));
    let material3 = Arc::new(Metal::new(Color::new(0.7, 0.6, 0.5), 0.0));
    world.add(Box::new(Sphere::new(
        Point3::new(4.0, 1.0, 0.0),
        1.0,
        material3,
    )));
    world
}
fn main() {
    // Image
    const ASPECT_RATIO: f64 = 3.0 / 2.0;
    const IMAGE_WIDTH: i32 = 1200;
    const IMAGE_HEIGHT: i32 = (IMAGE_WIDTH as f64 / ASPECT_RATIO) as i32;
    const SAMPLES_PER_PIXEL: i32 = 500;
    const MAX_DEPTH: i32 = 50;
    // World
    let world = random_scene();
    // Camera
    let lookfrom = Point3::new(13.0, 2.0, 3.0);
    let lookat = Point3::new(0.0, 0.0, 0.0);
    let vup = Point3::new(0.0, 1.0, 0.0);
    let dist_to_focus = 10.0;
    let aperture = 0.1;
    let cam = Camera::new(
        lookfrom,
        lookat,
        vup,
        20.0,
        ASPECT_RATIO,
        aperture,
        dist_to_focus,
    );
    // Render
    print!("P3\n{} {}\n255\n", IMAGE_WIDTH, IMAGE_HEIGHT);
    for j in (0..IMAGE_HEIGHT).rev() {
        eprint!("\rScanlines remaining: {} ", j);
        let pixel_colors: Vec<_> = (0..IMAGE_WIDTH)
            .into_par_iter()
            .map(|i| {
                let mut pixel_color = Color::new(0.0, 0.0, 0.0);
                for _ in 0..SAMPLES_PER_PIXEL {
                    let u = ((i as f64) + common::random_double()) / (IMAGE_WIDTH - 1) as f64;
                    let v = ((j as f64) + common::random_double()) / (IMAGE_HEIGHT - 1) as f64;
                    let r = cam.get_ray(u, v);
                    pixel_color += ray_color(&r, &world, MAX_DEPTH);
                }
                pixel_color
            })
            .collect();
        for pixel_color in pixel_colors {
            color::write_color(&mut io::stdout(), pixel_color, SAMPLES_PER_PIXEL);
        }
    }
    eprint!("\nDone.\n");
}
And this time, our code compiles!
If you try to generate the image again, it should be much faster and use all CPU cores available.