Import Libraries

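This section imports PyTorch, tqdm for progress bars, odak for optical data processing, and Matplotlib for plotting. The cmf and perceived_spectrum modules provide load_LMS_data and get_intensity, which are used later to load the cone sensitivities and the reference display spectrum. FutureWarning messages are suppressed to keep the output readable.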

import torch
import tqdm
import odak
import matplotlib.pyplot as plt
import cmf
import perceived_spectrum
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

Gaussian Function

This function defines a Gaussian curve with parameters for mean, standard deviation, and amplitude. It’s used to generate Gaussian-like spectra for the primaries.

def gaussian(x, mean, std, amplitude):
    return amplitude * torch.exp(-0.5 * ((x - mean) / std) ** 2)
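
As a quick check of the function's behavior, the sketch below evaluates it on a wavelength grid and locates the peak. The 400 to 700 nm range, the 1 nm step, and the parameter values are assumptions chosen for illustration; they are not values used by the optimization.

```python
import torch


def gaussian(x, mean, std, amplitude):
    return amplitude * torch.exp(-0.5 * ((x - mean) / std) ** 2)


# Hypothetical wavelength grid: 400 nm to 700 nm in 1 nm steps (301 samples).
wavelengths = torch.arange(400., 701., 1.)
curve = gaussian(wavelengths, mean=532., std=10., amplitude=0.8)

# The curve peaks at the mean, where its value equals the amplitude.
peak_index = torch.argmax(curve)
peak_wavelength = wavelengths[peak_index].item()
peak_value = curve[peak_index].item()
print(peak_wavelength, peak_value)
```

One standard deviation away from the mean the curve has dropped to exp(-0.5), roughly 61%, of its amplitude.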

Evaluate Function

The evaluate function computes the loss between the predicted and target LMS values. It also adds penalties for negative spectrum values and for spectrum values above 1, which regularize the spectrum.

The loss function is defined as:

\[ \text{Loss} = w_0 \cdot \text{LMS\_L2} + w_1 \cdot \text{Spectrum\_Negative} + w_2 \cdot \text{Spectrum\_Peak} \]

Where:

  • \(\text{LMS\_L2} = \text{MSE}(\text{predicted}, \text{target})\) : The mean squared error between the predicted and target LMS values.
  • \(\text{Spectrum\_Negative} = \left| \sum_{i:\, \text{spectrum}[i] < 0} \text{spectrum}[i] \right|\) : The magnitude of the sum of all negative spectrum values, which penalizes negative values in the spectrum.
  • \(\text{Spectrum\_Peak} = \sum_{i:\, \text{spectrum}[i] > 1} \text{spectrum}[i]\) : The sum of all spectrum values that exceed 1, which penalizes peaks above 1.
  • \(w_0, w_1, w_2\) : The weights applied to each term, adjusting its relative importance. In the code they are the first three entries of weights; the remaining entries of weights and the target_stds argument are not used.
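
def evaluate(predicted, target, spectrum, target_stds=[0.2, 0.2, 0.2], weights=[1., 1., 1., 0., 0.]):
    # Mean squared error between the predicted and target LMS images.
    image_lms_l2 = torch.nn.MSELoss()(predicted, target)
    # Magnitude of the summed negative spectrum values.
    spectrum_negative = torch.abs(torch.sum(spectrum[spectrum < 0]))
    # Sum of the spectrum values above 1.
    spectrum_peak = torch.sum(spectrum[spectrum > 1])
    # Only the first three weights are used; target_stds is unused.
    loss = weights[0] * image_lms_l2 + weights[1] * spectrum_negative + weights[2] * spectrum_peak
    return loss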
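
To make the three terms concrete, here is a small worked example that repeats the arithmetic of evaluate on hand-picked tensors. All input values are hypothetical and chosen so that each term can be verified by hand; the weights are assumed to be 1 for every term.

```python
import torch

# Hypothetical inputs, picked so that every term is easy to check by hand.
predicted = torch.tensor([0., 1., 2.])
target = torch.tensor([0., 1., 5.])
spectrum = torch.tensor([-0.25, 0.5, 1.5, -0.5, 2.0])

# (0^2 + 0^2 + 3^2) / 3 = 3.0
lms_l2 = torch.nn.functional.mse_loss(predicted, target)
# |(-0.25) + (-0.5)| = 0.75
spectrum_negative = torch.abs(torch.sum(spectrum[spectrum < 0]))
# 1.5 + 2.0 = 3.5
spectrum_peak = torch.sum(spectrum[spectrum > 1])

weights = [1., 1., 1.]
loss = weights[0] * lms_l2 + weights[1] * spectrum_negative + weights[2] * spectrum_peak
print(loss.item())
```

Note that the peak term sums the full values above 1 rather than only the excess over 1, so a single sample at 1.5 contributes 1.5 to the loss.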

Get Primaries

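This function generates a spectrum for each primary. Each primary is modeled as a sum of several Gaussian components: the function evaluates the components of a primary over the wavelength grid and sums them along the component axis.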

def get_primaries(number_of_primaries, means, stds, amplitudes, wavelengths, device):
    # One row per primary, one column per wavelength sample.
    spectrum = torch.zeros((number_of_primaries, wavelengths.shape[1]), device=device)
    for i in range(number_of_primaries):
        # Evaluate all Gaussian components of primary i, then sum over the component axis.
        spectrum[i] = torch.sum(gaussian(x=wavelengths, mean=means[i], std=stds[i], amplitude=amplitudes[i]), dim=0)
    return spectrum
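
The broadcasting behind the sum can be followed on a small example. The sketch below builds one primary from two Gaussian components; the grid, the component count, and all parameter values are assumptions made for illustration, with parameters shaped (number_of_primaries, p_count, 1) and the wavelength grid repeated to (p_count, number of samples).

```python
import math
import torch


def gaussian(x, mean, std, amplitude):
    return amplitude * torch.exp(-0.5 * ((x - mean) / std) ** 2)


number_of_primaries, p_count = 3, 2
wavelengths = torch.arange(400., 701., 1.)           # shape (301,)
grid = wavelengths.unsqueeze(0).repeat(p_count, 1)   # shape (p_count, 301)

# Hypothetical parameters with shape (number_of_primaries, p_count, 1).
means = torch.tensor([[[450.], [470.]], [[530.], [550.]], [[610.], [630.]]])
stds = torch.full((number_of_primaries, p_count, 1), 10.)
amplitudes = torch.full((number_of_primaries, p_count, 1), 0.5)

# means[0] has shape (p_count, 1), so each row of the grid gets its own component.
components = gaussian(grid, means[0], stds[0], amplitudes[0])  # (p_count, 301)
primary = components.sum(dim=0)                                # (301,)

# At 460 nm both components are one standard deviation from their mean.
expected_at_460 = 2 * 0.5 * math.exp(-0.5)
print(primary[60].item(), expected_at_460)
```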

Optimize Primaries

This is the core optimization function. It tunes the means, standard deviations, and amplitudes of the Gaussian components of each primary so that the resulting LMS values match the target LMS values, using the Adam optimizer.

def optimize_primaries(wavelengths, target_image_rgb, target_image_lms, number_of_primaries=3, number_of_iterations=100, learning_rate=1e-2, p_count=3, device=torch.device('cpu')):
    # Random initial parameters: p_count Gaussian components per primary.
    # The (440, 532, 620) scale broadcasts over the first axis, so this initialization assumes three primaries.
    means = torch.rand(number_of_primaries, p_count, 1, requires_grad=False, device=device) * torch.tensor([440., 532., 620.], device=device).unsqueeze(-1).unsqueeze(-1)
    stds = torch.rand(number_of_primaries, p_count, 1, requires_grad=False, device=device) * 10.
    amplitudes = torch.rand(number_of_primaries, p_count, 1, requires_grad=False, device=device) * 1e-2
    # Enable gradients after the initial scaling so the parameters stay leaf tensors.
    means.requires_grad = True
    stds.requires_grad = True
    amplitudes.requires_grad = True
    optimizer = torch.optim.Adam([means, stds, amplitudes], lr=learning_rate)
    # One copy of the wavelength grid per Gaussian component.
    wavelengths = wavelengths.unsqueeze(0).repeat(p_count, 1)
    t = tqdm.tqdm(range(number_of_iterations), leave=False, dynamic_ncols=True)
    for iteration in t:
        optimizer.zero_grad()
        optimized_spectrum = get_primaries(number_of_primaries, means, stds, amplitudes, wavelengths, device=device)
        estimation_lms = convert_rgb_to_lms(target_image_rgb, optimized_spectrum)
        loss = evaluate(estimation_lms, target_image_lms, optimized_spectrum)
        # The graph is rebuilt on every iteration, so it does not need to be retained.
        loss.backward()
        optimizer.step()
    return optimized_spectrum.detach()
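
The structure of the loop (zero the gradients, run the forward pass, compute the loss, backpropagate, step) can be tried on a much smaller problem. The following is a minimal sketch, not the optimization above: it fits a single Gaussian directly to a target spectrum with a plain MSE loss, and the grid, target parameters, starting guesses, learning rate, and iteration count are all assumptions chosen so that it runs in well under a second.

```python
import torch


def gaussian(x, mean, std, amplitude):
    return amplitude * torch.exp(-0.5 * ((x - mean) / std) ** 2)


# Hypothetical 1 nm wavelength grid and target spectrum.
wavelengths = torch.arange(400., 701., 1.)
target = gaussian(wavelengths, mean=532., std=15., amplitude=0.8)

# Hypothetical starting guesses for the three trainable parameters.
mean = torch.tensor(520., requires_grad=True)
std = torch.tensor(20., requires_grad=True)
amplitude = torch.tensor(0.5, requires_grad=True)
optimizer = torch.optim.Adam([mean, std, amplitude], lr=5e-2)

losses = []
for iteration in range(300):
    optimizer.zero_grad()
    # The forward pass is recomputed every iteration, building a fresh graph.
    estimate = gaussian(wavelengths, mean, std, amplitude)
    loss = torch.nn.functional.mse_loss(estimate, target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

initial_loss, final_loss = losses[0], losses[-1]
print(initial_loss, final_loss)
```

Because Adam normalizes each parameter's update, the mean (in nanometers) and the amplitude (of order 1) move by roughly the same absolute step per iteration, which is worth keeping in mind when choosing a learning rate for parameters with very different scales.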

Convert RGB to LMS Function

The convert_rgb_to_lms function converts an image in RGB color space to the LMS color space using the sensitivity of the L, M, and S cones of the human visual system.

The LMS cone response is calculated as:

\[ \text{LMS} = \sum_{i=1}^{3} \text{RGB}_i \cdot \text{Spectrum}_i \cdot \text{Sensitivity}_i \]

Where:

  • RGB_i: The i-th color channel (Red, Green, or Blue) of the image.
  • Spectrum_i: The spectrum of the i-th primary.
  • Sensitivity_i: The sensitivity of the L, M, and S cones for each wavelength.

The function performs the following steps:

  1. Multiplies each RGB channel by its corresponding spectrum.
  2. Multiplies the resulting image spectrum by the cone sensitivities for the L, M, and S cones.
  3. Sums the results to get the LMS representation.

def convert_rgb_to_lms(image_rgb, spectrum):
    wavelengths, l_cone_sensitivity, m_cone_sensitivity, s_cone_sensitivity = cmf.load_LMS_data()
    # Step 1: scale each RGB channel by the spectrum of its primary.
    image_spectrum = torch.stack((
        image_rgb[0] * spectrum[0].unsqueeze(-1).unsqueeze(-1),
        image_rgb[1] * spectrum[1].unsqueeze(-1).unsqueeze(-1),
        image_rgb[2] * spectrum[2].unsqueeze(-1).unsqueeze(-1)
    ), dim=0)
    # Step 2: weight the image spectrum by the cone sensitivities.
    image_lms = torch.stack((
        image_spectrum[0] * l_cone_sensitivity[0].unsqueeze(-1).unsqueeze(-1),
        image_spectrum[1] * m_cone_sensitivity[1].unsqueeze(-1).unsqueeze(-1),
        image_spectrum[2] * s_cone_sensitivity[2].unsqueeze(-1).unsqueeze(-1)
    ), dim=0)
    # Step 3: sum over the wavelength axis.
    return torch.sum(image_lms, dim=1)
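
The three steps can be followed on synthetic data without loading the cone sensitivity data. This is a sketch under stated assumptions rather than a reproduction of the function: the image, the spectra, and especially the sensitivity tensor are random placeholders, and each sensitivity curve is assumed to be a one-dimensional tensor over wavelength, which may differ from what cmf.load_LMS_data returns. Channel i is paired with Spectrum_i and Sensitivity_i, following the list of symbols above.

```python
import torch

torch.manual_seed(0)
number_of_wavelengths, height, width = 5, 2, 2

image_rgb = torch.rand(3, height, width)
spectrum = torch.rand(3, number_of_wavelengths)
# Hypothetical stand-in for the cone sensitivities, one curve per channel.
sensitivity = torch.rand(3, number_of_wavelengths)

# Step 1: (3, 1, H, W) * (3, N, 1, 1) broadcasts to (3, N, H, W).
image_spectrum = image_rgb.unsqueeze(1) * spectrum.unsqueeze(-1).unsqueeze(-1)
# Step 2: weight every wavelength sample by the sensitivity at that wavelength.
weighted = image_spectrum * sensitivity.unsqueeze(-1).unsqueeze(-1)
# Step 3: sum over the wavelength axis to get a (3, H, W) result.
image_lms = weighted.sum(dim=1)

# Cross-check: under these assumptions each channel is scaled by one scalar gain.
gains = (spectrum * sensitivity).sum(dim=1)
reference = image_rgb * gains.view(3, 1, 1)
print(image_lms.shape, torch.allclose(image_lms, reference))
```

The cross-check shows that, with this pairing, the wavelength sum reduces to a per-channel gain, so the gains could be computed once instead of materializing the full (3, N, H, W) tensor.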

Main Function

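The main function is the entry point of the program. It loads the LMS cone sensitivity data, builds the target spectrum of an iPad Pro display with perceived_spectrum.get_intensity, generates a random target RGB image together with its LMS representation, and calls the optimize_primaries function to find the optimized display primaries. It then prints the peak positions and plots the optimized and target spectra.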

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    number_of_primaries = 3
    wavelengths, l_cone_sensitivity, m_cone_sensitivity, s_cone_sensitivity = cmf.load_LMS_data()
    # Reference spectrum of the target display, one row per primary.
    target_spectrum = torch.as_tensor(perceived_spectrum.get_intensity(wavelengths.cpu().numpy(), display_type='iPadPro'), dtype=torch.float32, device=device).squeeze(-1).permute(1, 0)
    # Random RGB image and its LMS representation under the target spectrum.
    target_image_rgb = torch.rand(3, 100, 100, device=device)
    target_image_lms = convert_rgb_to_lms(target_image_rgb, target_spectrum)
    optimized_primaries = optimize_primaries(
        wavelengths,
        target_image_rgb,
        target_image_lms,
        number_of_primaries=number_of_primaries,
        p_count=4,
        learning_rate=1e-6,
        number_of_iterations=1000000,
        device=device
    )
    plt.figure()
    # Peak positions: argmax index offset by the first wavelength.
    print(torch.argmax(optimized_primaries, dim=1) + wavelengths[0])
    print(torch.argmax(target_spectrum, dim=1) + wavelengths[0])
    for i in range(number_of_primaries):
        plt.plot(optimized_primaries[i].cpu().numpy())
    plt.figure()
    for i in range(number_of_primaries):
        plt.plot(target_spectrum[i].cpu().numpy())
    plt.show()
    return True
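
The two print calls in main report each peak as the argmax index plus the first wavelength, which equals the peak wavelength only when the grid is sampled in 1 nm steps. If the grid spacing is different, indexing the wavelength tensor with the argmax result gives the peak wavelength directly. The sketch below illustrates the difference on a hypothetical 5 nm grid with made-up spectra.

```python
import torch

# Hypothetical grid with 5 nm spacing (61 samples) and three made-up spectra.
wavelengths = torch.arange(400., 701., 5.)
centers = torch.tensor([450., 530., 620.]).unsqueeze(-1)
spectra = torch.exp(-0.5 * ((wavelengths - centers) / 10.) ** 2)  # (3, 61)

peak_indices = torch.argmax(spectra, dim=1)

# Index lookup: valid for any grid spacing.
peak_wavelengths = wavelengths[peak_indices]
# Index plus first wavelength: only correct for a 1 nm step.
offset_estimate = peak_indices + wavelengths[0]

print(peak_wavelengths.tolist())
print(offset_estimate.tolist())
```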

Run the Main Function

Here, the main function is called to execute the optimization and spectrum generation.

if __name__ == '__main__':
    main()