|
- # Copyright (c) Meta Platforms, Inc. All Rights Reserved
-
- import torch
- import numpy as np
-
-
- class DistinctColors:
- def __init__(self):
- colors = [
- '#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f55031', '#911eb4', '#42d4f4', '#bfef45', '#fabed4', '#469990',
- '#dcb1ff', '#404E55', '#fffac8', '#809900', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#a9a9a9', '#f032e6',
- '#806020', '#ffffff',
-
- "#FAD09F", "#FF8A9A", "#D157A0", "#BEC459", "#456648", "#0030ED", "#3A2465", "#34362D", "#B4A8BD", "#0086AA",
- "#452C2C", "#636375", "#A3C8C9", "#FF913F", "#938A81", "#575329", "#00FECF", "#B05B6F", "#8CD0FF", "#3B9700",
-
- "#04F757", "#C8A1A1", "#1E6E00",
- "#7900D7", "#A77500", "#6367A9", "#A05837", "#6B002C", "#772600", "#D790FF", "#9B9700",
- "#549E79", "#FFF69F", "#201625", "#72418F", "#BC23FF", "#99ADC0", "#3A2465", "#922329",
- "#5B4534", "#FDE8DC", "#404E55", "#0089A3", "#CB7E98", "#A4E804", "#324E72", "#6A3A4C",
- ]
- self.hex_colors = colors
- # 0 = crimson / red, 1 = green, 2 = yellow, 3 = blue
- # 4 = orange, 5 = purple, 6 = sky blue, 7 = lime green
- self.colors = [hex_to_rgb(c) for c in colors]
- self.color_assignments = {}
- self.color_ctr = 0
- self.fast_color_index = torch.from_numpy(np.array([hex_to_rgb(colors[i % len(colors)]) for i in range(8096)] + [hex_to_rgb('#000000')]))
-
- def get_color(self, index, override_color_0=False):
- colors = [x for x in self.hex_colors]
- if override_color_0:
- colors[0] = "#3f3f3f"
- colors = [hex_to_rgb(c) for c in colors]
- if index not in self.color_assignments:
- self.color_assignments[index] = colors[self.color_ctr % len(self.colors)]
- self.color_ctr += 1
- return self.color_assignments[index]
-
- def get_color_fast_torch(self,
- index): # ndarray, (H, W)
- return self.fast_color_index[index]
-
- def get_color_fast_numpy(self, index, override_color_0=False):
- index = np.array(index).astype(np.int32)
- if override_color_0:
- colors = [x for x in self.hex_colors]
- colors[0] = "#3f3f3f"
- fast_color_index = torch.from_numpy(np.array([hex_to_rgb(colors[i % len(colors)]) for i in range(8096)] + [hex_to_rgb('#000000')]))
- return fast_color_index[index % fast_color_index.shape[0]].numpy()
- else:
- return self.fast_color_index[index % self.fast_color_index.shape[0]].numpy()
-
- def apply_colors(self, arr):
- out_arr = torch.zeros([arr.shape[0], 3])
-
- for i in range(arr.shape[0]):
- out_arr[i, :] = torch.tensor(self.get_color(arr[i].item()))
- return out_arr
-
- def apply_colors_fast_torch(self, arr):
- return self.fast_color_index[arr % self.fast_color_index.shape[0]]
-
- def apply_colors_fast_numpy(self, arr):
- return self.fast_color_index.numpy()[arr % self.fast_color_index.shape[0]]
-
-
- def hex_to_rgb(x):
- return [int(x[i:i + 2], 16) / 255 for i in (1, 3, 5)]
-
-
- def visualize_distinct_colors(num_vis=32):
- from PIL import Image
- dc = DistinctColors()
- labels = np.ones((1, 64, 64)).astype(np.int)
- all_labels = []
- for i in range(num_vis):
- all_labels.append(labels * i)
- all_labels = np.concatenate(all_labels, 0)
- shape = all_labels.shape
- labels_colored = dc.get_color_fast_numpy(all_labels.reshape(-1))
- labels_colored = (labels_colored.reshape(shape[0] * shape[1], shape[2], 3) * 255).astype(np.uint8)
- Image.fromarray(labels_colored).save("colormap.png")
-
- #
- # if __name__ == "__main__":
- # visualize_distinct_colors()
|