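This section imports the required libraries: PyTorch for tensors and optimization, tqdm for progress visualization, odak for optical data processing, and matplotlib for plotting. The cmf and perceived_spectrum modules provide the cone sensitivity data and the display spectra used later.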
import torch
import tqdm
import odak
import matplotlib.pyplot as plt
import cmf
import perceived_spectrum
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
This function defines a Gaussian curve with parameters for mean, standard deviation, and amplitude. It’s used to generate Gaussian-like spectra for the primaries.
def gaussian(x, mean, std, amplitude):
    return amplitude * torch.exp(-0.5 * ((x - mean) / std) ** 2)
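As a quick check of how the parameters shape the curve, the sketch below evaluates the same expression on a sampling grid. The 400 nm to 700 nm grid and the parameter values are assumptions chosen for illustration, not values from the text.

```python
import torch


def gaussian(x, mean, std, amplitude):
    # Same expression as in the text: an unnormalized Gaussian curve.
    return amplitude * torch.exp(-0.5 * ((x - mean) / std) ** 2)


# Hypothetical sampling grid: 400 nm to 700 nm in 1 nm steps.
wavelengths = torch.linspace(400., 700., 301)
curve = gaussian(wavelengths, mean=532., std=10., amplitude=0.8)

# The peak sits at the mean and its height equals the amplitude.
peak_index = torch.argmax(curve)
peak_wavelength = wavelengths[peak_index].item()
peak_value = curve[peak_index].item()

# One standard deviation from the mean, the curve falls to exp(-0.5) of its peak.
value_at_one_std = gaussian(torch.tensor(542.), 532., 10., 0.8).item()
```

Because the curve is not normalized, amplitude directly sets the peak height rather than the area under the curve.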
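The evaluate function calculates the loss between the predicted LMS values and the target LMS values. It also penalizes negative spectrum values and spectrum values above 1, which helps to regularize the spectrum.
The loss function is defined as:
\[ \text{Loss} = w_0 \cdot \text{LMS\_L2} + w_1 \cdot \text{Spectrum\_Negative} + w_2 \cdot \text{Spectrum\_Peak} \]
Where:
\( \text{LMS\_L2} \) is the mean squared error between the predicted and target LMS values.
\( \text{Spectrum\_Negative} \) is the absolute value of the sum of all negative spectrum samples.
\( \text{Spectrum\_Peak} \) is the sum of all spectrum samples greater than 1.
\( w_0, w_1, w_2 \) are the first three entries of weights.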
def evaluate(predicted, target, spectrum, target_stds=[0.2, 0.2, 0.2], weights=[1., 1., 1., 0., 0.]):
    # Note: target_stds and weights[3:] are not used in this function.
    # Mean squared error between predicted and target LMS values.
    image_lms_l2 = torch.nn.MSELoss()(predicted, target)
    # Penalty on negative spectrum samples.
    spectrum_negative = torch.abs(torch.sum(spectrum[spectrum < 0]))
    # Penalty on spectrum samples above 1.
    spectrum_peak = torch.sum(spectrum[spectrum > 1])
    loss = weights[0] * image_lms_l2 + weights[1] * spectrum_negative + weights[2] * spectrum_peak
    return loss
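To make the three terms concrete, here is a small worked example on toy tensors. All values are made up for illustration; the expressions mirror the ones in evaluate.

```python
import torch

# Toy spectrum with one negative sample (-0.1) and one sample above 1 (1.5).
spectrum = torch.tensor([[0.2, -0.1, 0.5],
                         [1.5, 0.3, 0.0]])
predicted = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 5.0])

# Mean squared error: (0 + 0 + 4) / 3.
lms_l2 = torch.nn.functional.mse_loss(predicted, target)
# Absolute value of the sum of negative samples: |-0.1|.
spectrum_negative = torch.abs(torch.sum(spectrum[spectrum < 0]))
# Sum of the samples above 1 (the values themselves, not the excess over 1).
spectrum_peak = torch.sum(spectrum[spectrum > 1])

weights = [1., 1., 1.]
loss = weights[0] * lms_l2 + weights[1] * spectrum_negative + weights[2] * spectrum_peak
```

With unit weights the three terms contribute roughly 1.333, 0.1, and 1.5. Note that the peak term adds the full sample value once it crosses 1, so the penalty jumps rather than growing from zero at the threshold.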
This function generates a spectrum for each primary by summing multiple Gaussian functions.
def get_primaries(number_of_primaries, means, stds, amplitudes, wavelengths, device):
    # One row per primary, one column per wavelength sample.
    spectrum = torch.zeros((number_of_primaries, wavelengths.shape[1]), device=device)
    for i in range(number_of_primaries):
        # Sum the Gaussian components of primary i along the component axis.
        spectrum[i] = torch.sum(gaussian(x=wavelengths, mean=means[i], std=stds[i], amplitude=amplitudes[i]), dim=0)
    return spectrum
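The shapes involved are easier to follow in a vectorized form. The sketch below is an alternative formulation, not the loop used in the text: it broadcasts parameters of shape (number_of_primaries, p_count, 1) against a wavelength grid of shape (p_count, N) and sums over the component axis. The grid and parameter values are assumptions for illustration.

```python
import torch

number_of_primaries, p_count = 3, 4
# Hypothetical grid; the text obtains its wavelengths from cmf.load_LMS_data().
wavelengths = torch.linspace(400., 700., 301)
grid = wavelengths.unsqueeze(0).repeat(p_count, 1)             # (p_count, N)

# Every component shares the same parameters here to keep the result predictable.
means = torch.full((number_of_primaries, p_count, 1), 532.)
stds = torch.full((number_of_primaries, p_count, 1), 10.)
amplitudes = torch.full((number_of_primaries, p_count, 1), 0.25)

# (3, 4, 1) broadcast against (4, N) gives (3, 4, N).
components = amplitudes * torch.exp(-0.5 * ((grid - means) / stds) ** 2)
# Summing over the component axis gives one spectrum per primary: (3, N).
spectrum = components.sum(dim=1)
```

Four identical components of amplitude 0.25 add up to a peak of 1.0, which shows how a sum of Gaussians can exceed the amplitude of any single component.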
This is the core optimization function. It tunes the parameters (means, stds, amplitudes) of the primary spectra to match the target LMS values using gradient descent with the Adam optimizer.
def optimize_primaries(wavelengths, target_image_rgb, target_image_lms, number_of_primaries=3, number_of_iterations=100, learning_rate=1e-2, p_count=3, device=torch.device('cpu')):
    # Random initialization; the scaling by three wavelengths assumes three primaries.
    means = torch.rand(number_of_primaries, p_count, 1, requires_grad=False, device=device) * torch.tensor([440., 532., 620.], device=device).unsqueeze(-1).unsqueeze(-1)
    stds = torch.rand(number_of_primaries, p_count, 1, requires_grad=False, device=device) * 10.
    amplitudes = torch.rand(number_of_primaries, p_count, 1, requires_grad=False, device=device) * 1e-2
    means.requires_grad = True
    stds.requires_grad = True
    amplitudes.requires_grad = True
    optimizer = torch.optim.Adam([means, stds, amplitudes], lr=learning_rate)
    # Repeat the wavelength grid once per Gaussian component.
    wavelengths = wavelengths.unsqueeze(0).repeat(p_count, 1)
    t = tqdm.tqdm(range(number_of_iterations), leave=False, dynamic_ncols=True)
    for iteration in t:
        optimizer.zero_grad()
        optimized_spectrum = get_primaries(number_of_primaries, means, stds, amplitudes, wavelengths, device=device)
        estimation_lms = convert_rgb_to_lms(target_image_rgb, optimized_spectrum)
        loss = evaluate(estimation_lms, target_image_lms, optimized_spectrum)
        loss.backward(retain_graph=True)
        optimizer.step()
    return optimized_spectrum.detach()
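The same optimization pattern can be seen in isolation on a much smaller problem. The following is a minimal sketch, assuming a single Gaussian fitted to a made-up target curve. The target, the starting values, and the per-parameter learning rates are all illustrative choices and do not come from the text, which uses one learning rate for all parameters.

```python
import torch

torch.manual_seed(0)
wavelengths = torch.linspace(400., 700., 301)
# Hypothetical target: a single Gaussian peak centered at 550 nm.
target = 0.5 * torch.exp(-0.5 * ((wavelengths - 550.) / 20.) ** 2)

mean = torch.tensor(530., requires_grad=True)
std = torch.tensor(30., requires_grad=True)
amplitude = torch.tensor(0.3, requires_grad=True)

# Mean and std live on a scale of tens of nanometers, amplitude on a scale
# below one, so this sketch gives them separate learning rates.
optimizer = torch.optim.Adam([
    {'params': [mean, std], 'lr': 0.5},
    {'params': [amplitude], 'lr': 0.01},
])

losses = []
for _ in range(500):
    optimizer.zero_grad()
    estimate = amplitude * torch.exp(-0.5 * ((wavelengths - mean) / std) ** 2)
    loss = torch.nn.functional.mse_loss(estimate, target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Recording the loss per iteration, as done here with the losses list, is a simple way to confirm that a chosen learning rate and iteration count are making progress.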
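The convert_rgb_to_lms function converts an image in RGB color space to the LMS color space using the sensitivity of the L, M, and S cones of the human visual system.
As implemented in the code that follows, the LMS cone response is calculated as:
\[ \text{LMS}_i = \sum_{\lambda} \text{RGB}_i \cdot \text{Spectrum}_i(\lambda) \cdot \text{Sensitivity}_i(\lambda) \]
Where:
\( \text{RGB}_i \) is the i-th channel of the input image.
\( \text{Spectrum}_i(\lambda) \) is the emission spectrum of the i-th primary.
\( \text{Sensitivity}_i(\lambda) \) is the cone sensitivity term paired with that channel (L, M, and S in order).
The function performs the following steps:
1. Loads the cone sensitivity data with cmf.load_LMS_data().
2. Scales the spectrum of each primary by the corresponding RGB channel to obtain a per-pixel spectrum.
3. Multiplies each channel's spectrum by its paired cone sensitivity term.
4. Sums over the wavelength axis to obtain the LMS image.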
def convert_rgb_to_lms(image_rgb, spectrum):
    wavelengths, l_cone_sensitivity, m_cone_sensitivity, s_cone_sensitivity = cmf.load_LMS_data()
    image_spectrum = torch.stack((
        image_rgb[0] * spectrum[0].unsqueeze(-1).unsqueeze(-1),
        image_rgb[1] * spectrum[1].unsqueeze(-1).unsqueeze(-1),
        image_rgb[2] * spectrum[2].unsqueeze(-1).unsqueeze(-1)
    ), dim=0)
    image_lms = torch.stack((
        image_spectrum[0] * l_cone_sensitivity[0].unsqueeze(-1).unsqueeze(-1),
        image_spectrum[1] * m_cone_sensitivity[1].unsqueeze(-1).unsqueeze(-1),
        image_spectrum[2] * s_cone_sensitivity[2].unsqueeze(-1).unsqueeze(-1)
    ), dim=0)
    return torch.sum(image_lms, dim=1)
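The tensor shapes in this conversion can be traced with synthetic data. The sketch below does not call cmf.load_LMS_data(), whose return layout is not shown in the text. Instead it assumes the three cone sensitivities are stacked as the rows of one hypothetical (3, N) tensor, and it pairs channel i with sensitivity row i as the function above does.

```python
import torch

torch.manual_seed(0)
N, H, W = 5, 2, 2
image_rgb = torch.rand(3, H, W)
spectrum = torch.rand(3, N)        # one emission spectrum per primary
# Hypothetical cone sensitivities, stacked as rows L, M, S (an assumption).
sensitivity = torch.rand(3, N)

# (3, 1, H, W) * (3, N, 1, 1) -> (3, N, H, W): spectrum emitted at each pixel.
image_spectrum = image_rgb.unsqueeze(1) * spectrum.unsqueeze(-1).unsqueeze(-1)
# Weight each channel by its paired sensitivity, then sum over the wavelength axis.
image_lms = torch.sum(image_spectrum * sensitivity.unsqueeze(-1).unsqueeze(-1), dim=1)

# Under this pairing, each channel is simply scaled by one scalar gain.
gains = torch.sum(spectrum * sensitivity, dim=1)               # (3,)
reference = image_rgb * gains.view(3, 1, 1)
```

Under these assumptions the result equals the input image with each channel multiplied by a single gain, which is a convenient sanity check when debugging shapes.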
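The main function is the entry point of the program. It loads the LMS data, computes the target display spectrum, generates a random target RGB image, and calls the optimize_primaries function to find the optimized display primaries. It then prints the peak locations and plots the optimized and target spectra.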
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    number_of_primaries = 3
    wavelengths, l_cone_sensitivity, m_cone_sensitivity, s_cone_sensitivity = cmf.load_LMS_data()
    target_spectrum = torch.as_tensor(perceived_spectrum.get_intensity(wavelengths.cpu().numpy(), display_type='iPadPro'), dtype=torch.float32, device=device).squeeze(-1).permute(1, 0)
    target_image_rgb = torch.rand(3, 100, 100, device=device)
    target_image_lms = convert_rgb_to_lms(target_image_rgb, target_spectrum)
    optimized_primaries = optimize_primaries(
        wavelengths,
        target_image_rgb,
        target_image_lms,
        number_of_primaries=number_of_primaries,
        p_count=4,
        learning_rate=1e-6,
        number_of_iterations=1000000,
        device=device
    )
    plt.figure()
    print(torch.argmax(optimized_primaries, dim=1) + wavelengths[0])
    print(torch.argmax(target_spectrum, dim=1) + wavelengths[0])
    for i in range(number_of_primaries):
        plt.plot(optimized_primaries[i].cpu().numpy())
    plt.figure()
    for i in range(number_of_primaries):
        plt.plot(target_spectrum[i].cpu().numpy())
    plt.show()
    return True
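The two print calls add the argmax index to the first wavelength, which yields a peak wavelength only when the grid is sampled in 1 nm steps. If the grid spacing differs, indexing the wavelength tensor with the argmax is a safer way to report peaks. The sketch below illustrates the difference on a hypothetical 5 nm grid with made-up spectra.

```python
import torch

# Hypothetical grid with 5 nm spacing: 400, 405, ..., 700.
wavelengths = torch.arange(400., 701., 5.)
centers = torch.tensor([450., 530., 620.]).unsqueeze(-1)
# Three synthetic spectra, one Gaussian each: shape (3, N).
spectra = torch.exp(-0.5 * ((wavelengths - centers) / 10.) ** 2)

peak_indices = torch.argmax(spectra, dim=1)
# Look the peaks up in the grid: correct for any spacing.
peaks_by_lookup = wavelengths[peak_indices]
# Index plus first wavelength: correct only for 1 nm spacing.
peaks_by_offset = peak_indices + wavelengths[0]
```

Here the lookup recovers 450, 530, and 620 nm, while the offset form reports 410, 426, and 444.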
Here, the main function is called to execute the optimization and spectrum generation.
if __name__ == '__main__':
    main()