import torch from torch import nn from torch.nn import functional as F from typing import Any, Dict, List, Tuple, Type, Optional from functools import partial import numpy as np import math from transformers import AutoImageProcessor, AutoModel # from ._registry import register_model # from ._pretrain import load_pretrained_weights def window_partition3D(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. (Hp, Wp): padded height and width before partition """ B, D, H, W, C = x.shape pad_d = (window_size - D % window_size) % window_size pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0 or pad_d > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_d)) Hp, Wp, Dp = H + pad_h, W + pad_w, D + pad_d x = x.view(B, Dp // window_size, window_size, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) return windows, (Dp, Hp, Wp) def window_unpartition3D(windows: torch.Tensor, window_size: int, pad_dhw: Tuple[int, int, int], dhw: Tuple[int, int, int]) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Dp, Hp, Wp = pad_dhw D, H, W = dhw B = windows.shape[0] // (Dp * Hp * Wp // window_size // window_size // window_size) x = windows.view(B, Dp // window_size, Hp // window_size, Wp // window_size, window_size, window_size, window_size, -1) x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, Dp, Hp, Wp, -1) if Hp > H or Wp > W or Dp > D: x = x[:, :D, :H, :W, :].contiguous() return x def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ Get relative positional embeddings according to the relative positions of query and key sizes. Args: q_size (int): size of query q. k_size (int): size of key k. rel_pos (Tensor): relative position embeddings (L, C). Returns: Extracted positional embeddings according to relative positions. """ max_rel_dist = int(2 * max(q_size, k_size) - 1) # Interpolate rel pos if needed. if rel_pos.shape[0] != max_rel_dist: # Interpolate rel pos. rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_d: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int, int], k_size: Tuple[int, int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 Args: attn (Tensor): attention map. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. q_size (Tuple): spatial sequence size of query q with (q_h, q_w). k_size (Tuple): spatial sequence size of key k with (k_h, k_w). Returns: attn (Tensor): attention map with added relative positional embeddings. """ q_d, q_h, q_w = q_size k_d, k_h, k_w = k_size Rd = get_rel_pos(q_d, k_d, rel_pos_d) Rh = get_rel_pos(q_h, k_h, rel_pos_h) Rw = get_rel_pos(q_w, k_w, rel_pos_w) B, _, dim = q.shape r_q = q.reshape(B, q_d, q_h, q_w, dim) rel_d = torch.einsum("bdhwc,dkc->bdhwk", r_q, Rd) rel_h = torch.einsum("bdhwc,hkc->bdhwk", r_q, Rh) rel_w = torch.einsum("bdhwc,wkc->bdhwk", r_q, Rw) attn = (attn.view(B, q_d, q_h, q_w, k_d, k_h, k_w) + rel_d[:, :, :, :, None, None] + rel_h[:, :, :, None, :, None] + rel_w[:, :, :, None, None, :]).view( B, q_d * q_h * q_w, k_d * k_h * k_w) return attn class PatchEmbed3D(nn.Module): """ Image to Patch Embedding. """ def __init__( self, kernel_size: Tuple[int, int] = (16, 16, 16), stride: Tuple[int, int] = (16, 16, 16), padding: Tuple[int, int] = (0, 0, 0), in_chans: int = 1, embed_dim: int = 768, ) -> None: """ Args: kernel_size (Tuple): kernel size of the projection layer. stride (Tuple): stride of the projection layer. padding (Tuple): padding size of the projection layer. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. """ super().__init__() self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) # B C X Y Z -> B X Y Z C x = x.permute(0, 2, 3, 4, 1) return x class MLPBlock(nn.Module): def __init__( self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU, ) -> None: super().__init__() self.lin1 = nn.Linear(embedding_dim, mlp_dim) self.lin2 = nn.Linear(mlp_dim, embedding_dim) self.act = act() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.lin2(self.act(self.lin1(x))) # Lightly adapted from # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa class MLP(nn.Module): def __init__( self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, ) -> None: super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList( nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) self.sigmoid_output = sigmoid_output def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) if self.sigmoid_output: x = F.sigmoid(x) return x class Attention(nn.Module): """Multi-head Attention block with relative position embeddings.""" def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: assert ( input_size is not None), "Input size must be provided if using relative positional encoding." # initialize relative positional embeddings self.rel_pos_d = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[2] - 1, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, D, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) qkv = self.qkv(x).reshape(B, D * H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # q, k, v with shape (B * nHead, H * W, C) q, k, v = qkv.reshape(3, B * self.num_heads, D * H * W, -1).unbind(0) attn = (q * self.scale) @ k.transpose(-2, -1) if self.use_rel_pos: attn = add_decomposed_rel_pos(attn, q, self.rel_pos_d, self.rel_pos_h, self.rel_pos_w, (D, H, W), (D, H, W)) attn = attn.softmax(dim=-1) x = (attn @ v).view(B, self.num_heads, D, H, W, -1).permute(0, 2, 3, 4, 1, 5).reshape(B, D, H, W, -1) x = self.proj(x) return x class LayerNorm3d(nn.Module): def __init__(self, num_channels: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None] return x class Block3D(nn.Module): """Transformer blocks with support of window attention and residual propagation blocks""" def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Optional[Tuple[int, int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. If it equals 0, then use global attention. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, input_size=input_size if window_size == 0 else (window_size, window_size, window_size), ) self.norm2 = norm_layer(dim) self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) self.window_size = window_size def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) # Window partition if self.window_size > 0: D, H, W = x.shape[1], x.shape[2], x.shape[3] x, pad_dhw = window_partition3D(x, self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: x = window_unpartition3D(x, self.window_size, pad_dhw, (D, H, W)) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa class ImageEncoderViT3D(nn.Module): def __init__( self, img_size: int = 256, patch_size: int = 16, in_chans: int = 1, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = (), ) -> None: """ Args: img_size (int): Input image size. patch_size (int): Patch size. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. depth (int): Depth of ViT. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_abs_pos (bool): If True, use absolute positional embeddings. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. global_attn_indexes (list): Indexes for blocks using global attention. """ super().__init__() self.img_size = img_size self.patch_embed = PatchEmbed3D( kernel_size=(patch_size, patch_size, patch_size), stride=(patch_size, patch_size, patch_size), in_chans=in_chans, embed_dim=embed_dim, ) self.pos_embed: Optional[nn.Parameter] = None if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter( torch.zeros(1, img_size // patch_size, img_size // patch_size, img_size // patch_size, embed_dim)) self.blocks = nn.ModuleList() for i in range(depth): block = Block3D( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer, act_layer=act_layer, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, window_size=window_size if i not in global_attn_indexes else 0, input_size=(img_size // patch_size, img_size // patch_size, img_size // patch_size), ) self.blocks.append(block) self.neck = nn.Sequential( nn.Conv3d( embed_dim, out_chans, kernel_size=1, bias=False, ), # nn.LayerNorm(out_chans), LayerNorm3d(out_chans), nn.Conv3d( out_chans, out_chans, kernel_size=3, padding=1, bias=False, ), LayerNorm3d(out_chans), # nn.LayerNorm(out_chans), ) def forward(self, x: torch.Tensor) -> torch.Tensor: # input_size = [1,1,256,256,256] # import IPython; IPython.embed() x = self.patch_embed(x) # x = [1,16,16,16,768] # import pdb; pdb.set_trace() if self.pos_embed is not None: x = x + self.pos_embed for blk in self.blocks: x = blk(x) # x = [1,16,16,16,768] # x = self.neck(x.permute(0, 4, 1, 2, 3)) # output_size = [1,256,16,16,16] return x class TwoWayTransformer3D(nn.Module): def __init__( self, depth: int, embedding_dim: int, num_heads: int, mlp_dim: int, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, ) -> None: """ A transformer decoder that attends to an input image using queries whose positional embedding is supplied. Args: depth (int): number of layers in the transformer embedding_dim (int): the channel dimension for the input embeddings num_heads (int): the number of heads for multihead attention. Must divide embedding_dim mlp_dim (int): the channel dimension internal to the MLP block activation (nn.Module): the activation to use in the MLP block """ super().__init__() self.depth = depth self.embedding_dim = embedding_dim self.num_heads = num_heads self.mlp_dim = mlp_dim self.layers = nn.ModuleList() for i in range(depth): self.layers.append( TwoWayAttentionBlock3D( embedding_dim=embedding_dim, num_heads=num_heads, mlp_dim=mlp_dim, activation=activation, attention_downsample_rate=attention_downsample_rate, skip_first_layer_pe=(i == 0), )) self.final_attn_token_to_image = DownscaleAttention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate) self.norm_final_attn = nn.LayerNorm(embedding_dim) def forward( self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w. image_pe (torch.Tensor): the positional encoding to add to the image. Must have the same shape as image_embedding. point_embedding (torch.Tensor): the embedding to add to the query points. Must have shape B x N_points x embedding_dim for any N_points. Returns: torch.Tensor: the processed point_embedding torch.Tensor: the processed image_embedding """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C bs, c, x, y, z = image_embedding.shape image_embedding = image_embedding.flatten(2).permute(0, 2, 1) image_pe = image_pe.flatten(2).permute(0, 2, 1) # Prepare queries queries = point_embedding keys = image_embedding # Apply transformer blocks and final layernorm for layer in self.layers: queries, keys = layer( queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe, ) # Apply the final attention layer from the points to the image q = queries + point_embedding k = keys + image_pe attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm_final_attn(queries) return queries, keys class DownscaleAttention(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ def __init__( self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, ) -> None: super().__init__() self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) self.v_proj = nn.Linear(embedding_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor: b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor: b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: # Input projections q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) # Separate into heads q = self._separate_heads(q, self.num_heads) k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) # Attention _, _, _, c_per_head = q.shape attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens attn = attn / math.sqrt(c_per_head) attn = torch.softmax(attn, dim=-1) # Get output out = attn @ v out = self._recombine_heads(out) out = self.out_proj(out) return out class TwoWayAttentionBlock3D(nn.Module): def __init__( self, embedding_dim: int, num_heads: int, mlp_dim: int = 2048, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, ) -> None: """ A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse inputs. Arguments: embedding_dim (int): the channel dimension of the embeddings num_heads (int): the number of heads in the attention layers mlp_dim (int): the hidden dimension of the mlp block activation (nn.Module): the activation of the mlp block skip_first_layer_pe (bool): skip the PE on the first layer """ super().__init__() self.self_attn = DownscaleAttention(embedding_dim, num_heads) self.norm1 = nn.LayerNorm(embedding_dim) self.cross_attn_token_to_image = DownscaleAttention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate) self.norm2 = nn.LayerNorm(embedding_dim) self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) self.cross_attn_image_to_token = DownscaleAttention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate) self.skip_first_layer_pe = skip_first_layer_pe def forward(self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # Self attention block if self.skip_first_layer_pe: queries = self.self_attn(q=queries, k=queries, v=queries) else: q = queries + query_pe attn_out = self.self_attn(q=q, k=q, v=queries) queries = queries + attn_out queries = self.norm1(queries) # Cross attention block, tokens attending to image embedding q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm2(queries) # MLP block mlp_out = self.mlp(queries) queries = queries + mlp_out queries = self.norm3(queries) # Cross attention block, image embedding attending to tokens q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) keys = keys + attn_out keys = self.norm4(keys) return queries, keys class MaskDecoder3D(nn.Module): def __init__( self, *, transformer_dim: int, # transformer: nn.Module , num_multimask_outputs: int = 3, activation: Type[nn.Module] = nn.GELU, iou_head_depth: int = 3, iou_head_hidden_dim: int = 256, ) -> None: """ Predicts masks given an image and prompt embeddings, using a transformer architecture. Arguments: transformer_dim (int): the channel dimension of the transformer transformer (nn.Module): the transformer used to predict masks num_multimask_outputs (int): the number of masks to predict when disambiguating masks activation (nn.Module): the type of activation to use when upscaling masks iou_head_depth (int): the depth of the MLP used to predict mask quality iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality """ super().__init__() self.transformer_dim = transformer_dim # self.transformer = transformer self.transformer = TwoWayTransformer3D( depth=2, embedding_dim=self.transformer_dim, mlp_dim=2048, num_heads=8, ) self.num_multimask_outputs = num_multimask_outputs self.iou_token = nn.Embedding(1, transformer_dim) self.num_mask_tokens = num_multimask_outputs + 1 self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) self.output_upscaling = nn.Sequential( nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), LayerNorm3d(transformer_dim // 4), activation(), nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), activation(), ) self.output_hypernetworks_mlps = nn.ModuleList([ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens) ]) self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth) def forward( self, image_embeddings: torch.Tensor, image_pe: torch.Tensor, sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, multimask_output: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predict masks given image and prompt embeddings. Arguments: image_embeddings (torch.Tensor): the embeddings from the image encoder image_pe (torch.Tensor): positional encoding with the shape of image_embeddings sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs multimask_output (bool): Whether to return multiple masks or a single mask. Returns: torch.Tensor: batched predicted masks torch.Tensor: batched predictions of mask quality """ masks, iou_pred = self.predict_masks( image_embeddings=image_embeddings, image_pe=image_pe, sparse_prompt_embeddings=sparse_prompt_embeddings, dense_prompt_embeddings=dense_prompt_embeddings, ) # Select the correct mask or masks for output if multimask_output: mask_slice = slice(1, None) else: mask_slice = slice(0, 1) masks = masks[:, mask_slice, :, :] iou_pred = iou_pred[:, mask_slice] # Prepare output return masks, iou_pred def predict_masks( self, image_embeddings: torch.Tensor, image_pe: torch.Tensor, sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Predicts masks. See 'forward' for more details.""" # Concatenate output tokens output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # Expand per-image data in batch direction to be per-mask if image_embeddings.shape[0] != tokens.shape[0]: src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) else: src = image_embeddings src = src + dense_prompt_embeddings if image_pe.shape[0] != tokens.shape[0]: pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) else: pos_src = image_pe b, c, x, y, z = src.shape # Run the transformer # import IPython; IPython.embed() hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, 0, :] mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :] # Upscale mask embeddings and predict masks using the mask tokens src = src.transpose(1, 2).view(b, c, x, y, z) upscaled_embedding = self.output_upscaling(src) hyper_in_list: List[torch.Tensor] = [] for i in range(self.num_mask_tokens): hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) hyper_in = torch.stack(hyper_in_list, dim=1) b, c, x, y, z = upscaled_embedding.shape masks = (hyper_in @ upscaled_embedding.view(b, c, x * y * z)).view(b, -1, x, y, z) # Generate mask quality predictions iou_pred = self.iou_prediction_head(iou_token_out) return masks, iou_pred class PositionEmbeddingRandom3D(nn.Module): """ Positional encoding using random spatial frequencies. """ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: super().__init__() if scale is None or scale <= 0.0: scale = 1.0 self.register_buffer( "positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats)), ) def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: """Positionally encode points that are normalized to [0,1].""" # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coords = 2 * coords - 1 coords = coords @ self.positional_encoding_gaussian_matrix coords = 2 * np.pi * coords # outputs d_1 x ... x d_n x C shape return torch.cat([torch.sin(coords), torch.cos(coords), torch.sin(coords)], dim=-1) def forward(self, size: Tuple[int, int, int]) -> torch.Tensor: """Generate positional encoding for a grid of the specified size.""" x, y, z = size device: Any = self.positional_encoding_gaussian_matrix.device grid = torch.ones((x, y, z), device=device, dtype=torch.float32) y_embed = grid.cumsum(dim=0) - 0.5 x_embed = grid.cumsum(dim=1) - 0.5 z_embed = grid.cumsum(dim=2) - 0.5 y_embed = y_embed / y x_embed = x_embed / x z_embed = z_embed / z pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1)) return pe.permute(3, 0, 1, 2) # C x X x Y x Z def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int, int]) -> torch.Tensor: """Positionally encode points that are not normalized to [0,1].""" coords = coords_input.clone() coords[:, :, 0] = coords[:, :, 0] / image_size[0] coords[:, :, 1] = coords[:, :, 1] / image_size[1] coords[:, :, 2] = coords[:, :, 2] / image_size[2] return self._pe_encoding(coords.to(torch.float)) # B x N x C class PromptEncoder3D(nn.Module): def __init__( self, embed_dim: int, image_embedding_size: Tuple[int, int, int], input_image_size: Tuple[int, int, int], mask_in_chans: int, activation: Type[nn.Module] = nn.GELU, ) -> None: """ Encodes prompts for input to SAM's mask decoder. Arguments: embed_dim (int): The prompts' embedding dimension image_embedding_size (tuple(int, int)): The spatial size of the image embedding, as (H, W). input_image_size (int): The padded size of the image as input to the image encoder, as (H, W). mask_in_chans (int): The number of hidden channels used for encoding input masks. activation (nn.Module): The activation to use when encoding input masks. """ super().__init__() self.embed_dim = embed_dim self.input_image_size = input_image_size self.image_embedding_size = image_embedding_size self.pe_layer = PositionEmbeddingRandom3D(embed_dim // 3) self.num_point_embeddings: int = 2 # pos/neg point point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] self.point_embeddings = nn.ModuleList(point_embeddings) self.not_a_point_embed = nn.Embedding(1, embed_dim) self.mask_input_size = (image_embedding_size[0], image_embedding_size[1], image_embedding_size[2]) self.mask_downscaling = nn.Sequential( nn.Conv3d(1, mask_in_chans // 4, kernel_size=2, stride=2), LayerNorm3d(mask_in_chans // 4), activation(), nn.Conv3d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), LayerNorm3d(mask_in_chans), activation(), nn.Conv3d(mask_in_chans, embed_dim, kernel_size=1), ) self.no_mask_embed = nn.Embedding(1, embed_dim) def get_dense_pe(self) -> torch.Tensor: """ Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the image encoding. Returns: torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w) """ return self.pe_layer(self.image_embedding_size).unsqueeze(0) # 1xXxYxZ def _embed_points( self, points: torch.Tensor, labels: torch.Tensor, pad: bool, ) -> torch.Tensor: """Embeds point prompts.""" points = points + 0.5 # Shift to center of pixel if pad: padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device) padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) points = torch.cat([points, padding_point], dim=1) labels = torch.cat([labels, padding_label], dim=1) point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) point_embedding[labels == -1] = 0.0 point_embedding[labels == -1] += self.not_a_point_embed.weight point_embedding[labels == 0] += self.point_embeddings[0].weight point_embedding[labels == 1] += self.point_embeddings[1].weight return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: """Embeds box prompts.""" boxes = boxes + 0.5 # Shift to center of pixel coords = boxes.reshape(-1, 2, 2) corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) corner_embedding[:, 0, :] += self.point_embeddings[2].weight corner_embedding[:, 1, :] += self.point_embeddings[3].weight return corner_embedding def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: """Embeds mask inputs.""" mask_embedding = self.mask_downscaling(masks) return mask_embedding def _get_batch_size( self, points: Optional[Tuple[torch.Tensor, torch.Tensor]], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor], ) -> int: """ Gets the batch size of the output given the batch size of the input prompts. """ if points is not None: return points[0].shape[0] elif boxes is not None: return boxes.shape[0] elif masks is not None: return masks.shape[0] else: return 1 def _get_device(self) -> torch.device: return self.point_embeddings[0].weight.device def forward( self, points: Optional[Tuple[torch.Tensor, torch.Tensor]], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: """ Embeds different types of prompts, returning both sparse and dense embeddings. Arguments: points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates and labels to embed. boxes (torch.Tensor or none): boxes to embed masks (torch.Tensor or none): masks to embed Returns: torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined by the number of input points and boxes. torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W) """ bs = self._get_batch_size(points, boxes, masks) sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) if points is not None: coords, labels = points point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) if boxes is not None: box_embeddings = self._embed_boxes(boxes) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) if masks is not None: dense_embeddings = self._embed_masks(masks) else: dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand( bs, -1, self.image_embedding_size[0], self.image_embedding_size[1], self.image_embedding_size[2]) return sparse_embeddings, dense_embeddings class E3D(nn.Module): mask_threshold: float = 0.0 image_format: str = "L" def __init__( self, image_encoder: ImageEncoderViT3D, prompt_encoder: PromptEncoder3D, mask_decoder: MaskDecoder3D, pixel_mean: List[float] = [123.675], pixel_std: List[float] = [58.395], ) -> None: """ E3D predicts object masks from an image and input prompts. Arguments: image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for efficient mask prediction. prompt_encoder (PromptEncoder): Encodes various types of input prompts. mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. pixel_mean (list(float)): Mean values for normalizing pixels in the input image. pixel_std (list(float)): Std values for normalizing pixels in the input image. """ super().__init__() self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder self.mask_decoder = mask_decoder self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) @property def device(self) -> Any: return self.pixel_mean.device @torch.no_grad() def forward( self, batched_input: List[Dict[str, Any]], ) -> List[Dict[str, torch.Tensor]]: """ Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using SamPredictor is recommended over calling the model directly. Arguments: batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt key can be excluded if it is not present. 'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model. 'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W). 'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already transformed to the input frame of the model. 'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN. 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of the model. 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW. Returns: (list(dict)): A list over input images, where each element is as dictionary with the following keys. 'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of input prompts, C is determined by multimask_output, and (H, W) is the original size of the image. 'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC. 'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed as mask input to subsequent iterations of prediction. """ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) image_embeddings = self.image_encoder(input_images) outputs = [] for image_record, curr_embedding in zip(batched_input, image_embeddings): if "point_coords" in image_record: points = (image_record["point_coords"], image_record["point_labels"]) else: points = None sparse_embeddings, dense_embeddings = self.prompt_encoder( points=points, boxes=image_record.get("boxes", None), masks=image_record.get("mask_inputs", None), ) low_res_masks, iou_predictions = self.mask_decoder( image_embeddings=curr_embedding.unsqueeze(0), image_pe=self.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=False, ) masks = self.postprocess_masks( low_res_masks, input_size=image_record["image"].shape[-3:], original_size=image_record["original_size"], ) masks = masks > self.mask_threshold outputs.append({ "masks": masks, "iou_predictions": iou_predictions, "low_res_logits": low_res_masks, }) return outputs def postprocess_masks( self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...], ) -> torch.Tensor: """ Remove padding and upscale masks to the original image size. Arguments: masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format. input_size (tuple(int, int)): The size of the image input to the model, in (H, W) format. Used to remove padding. original_size (tuple(int, int)): The original size of the image before resizing for input to the model, in (H, W) format. Returns: (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size. """ masks = F.interpolate( masks, (self.image_encoder.img_size, self.image_encoder.img_size, self.image_encoder.img_size), mode="bilinear", align_corners=False, ) masks = masks[..., :input_size[0], :input_size[1], :input_size[2]] masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) return masks def preprocess(self, x: torch.Tensor) -> torch.Tensor: """Normalize pixel values and pad to a square input.""" # Normalize colors x = (x - self.pixel_mean) / self.pixel_std # Pad d, h, w = x.shape[-3:] padd = self.image_encoder.img_size - d padh = self.image_encoder.img_size - h padw = self.image_encoder.img_size - w x = F.pad(x, (0, padw, 0, padh, 0, padd)) return x class S3dClassifier(nn.Module): def __init__(self): super(S3dClassifier, self).__init__() self.encoder_3d = ImageEncoderViT3D( img_size=128, patch_size=16, in_chans=1, out_chans = 384, num_heads = 12, embed_dim=768, depth=12, qkv_bias=True, use_rel_pos=True, window_size=14, mlp_ratio=4.0, global_attn_indexes = [2, 5, 8, 11], norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), ) self.relu = nn.GELU() self.linear_1 = nn.Linear(768*8*8*8, 768) self.linear_2 = nn.Linear(768, 768*10) self.linear_3 = nn.Linear(768*10, 300) self.classifier = nn.Linear(300, 1) self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.encoder_3d(x) x = x.reshape(x.size(0), -1) x = self.linear_1(x) x = self.relu(x) x = self.linear_2(x) x = self.relu(x) x = self.linear_3(x) x = self.relu(x) x = self.classifier(x) x = self.sigmoid(x) x = x.squeeze(1) return x class D2dClassifier(nn.Module): def __init__(self, pretrain_dir=None): super(D2dClassifier, self).__init__() self.dino = AutoModel.from_pretrained(pretrain_dir) self.relu = nn.GELU() self.linear_1 = nn.Linear(401*768, 768) self.linear_2 = nn.Linear(768, 768*10) self.linear_3 = nn.Linear(768*10, 300) self.classifier = nn.Linear(300, 1) self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.dino(x)[0] x = x.reshape(x.size(0), -1) x = self.linear_1(x) x = self.relu(x) x = self.linear_2(x) x = self.relu(x) x = self.linear_3(x) x = self.relu(x) x = self.classifier(x) x = self.sigmoid(x) x = x.squeeze(1) return x