Skip to content

Commit

Permalink
SH fix for spz (#16055)
Browse files Browse the repository at this point in the history
* SH support for raw ply

* forwarding SH to splats

* conversion with or without SH

* Fix up

* wrong method call for SH
  • Loading branch information
CedricGuillemet authored Jan 14, 2025
1 parent 40902f0 commit 26d7e7b
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { ToHalfFloat } from "core/Misc/textureTools";
import type { Material } from "core/Materials/material";
import { Scalar } from "core/Maths/math.scalar";
import { runCoroutineSync, runCoroutineAsync, createYieldingScheduler, type Coroutine } from "core/Misc/coroutine";
import { EngineStore } from "core/Engines/engineStore";

interface DelayedTextureUpdate {
covA: Uint16Array;
Expand Down Expand Up @@ -83,6 +84,11 @@ interface CompressedPLYChunk {
maxColor: Vector3;
}

// @internal
interface PLYConversionBuffers {
buffer: ArrayBuffer;
sh?: [];
}
/**
* Representation of the types
*/
Expand Down Expand Up @@ -148,6 +154,52 @@ const enum PLYValue {
MAX_COLOR_G,
MAX_COLOR_B,

SH_0,
SH_1,
SH_2,
SH_3,
SH_4,
SH_5,
SH_6,
SH_7,
SH_8,
SH_9,
SH_10,
SH_11,
SH_12,
SH_13,
SH_14,
SH_15,
SH_16,
SH_17,
SH_18,
SH_19,
SH_20,
SH_21,
SH_22,
SH_23,
SH_24,
SH_25,
SH_26,
SH_27,
SH_28,
SH_29,
SH_30,
SH_31,
SH_32,
SH_33,
SH_34,
SH_35,
SH_36,
SH_37,
SH_38,
SH_39,
SH_40,
SH_41,
SH_42,
SH_43,
SH_44,

UNDEFINED,
}

Expand Down Expand Up @@ -205,6 +257,18 @@ export interface PLYHeader {
* buffer for the data view
*/
buffer: ArrayBuffer;
/**
* degree of SH coefficients
*/
shDegree: number;
/**
* number of coefficient per splat
*/
shCoefficientCount: number;
/**
* buffer for SH coefficients
*/
shBuffer: ArrayBuffer | null;
}
/**
* Class used to render a gaussian splatting mesh
Expand Down Expand Up @@ -524,6 +588,96 @@ export class GaussianSplattingMesh extends Mesh {
return PLYValue.MAX_COLOR_G;
case "max_b":
return PLYValue.MAX_COLOR_B;
case "f_rest_0":
return PLYValue.SH_0;
case "f_rest_1":
return PLYValue.SH_1;
case "f_rest_2":
return PLYValue.SH_2;
case "f_rest_3":
return PLYValue.SH_3;
case "f_rest_4":
return PLYValue.SH_4;
case "f_rest_5":
return PLYValue.SH_5;
case "f_rest_6":
return PLYValue.SH_6;
case "f_rest_7":
return PLYValue.SH_7;
case "f_rest_8":
return PLYValue.SH_8;
case "f_rest_9":
return PLYValue.SH_9;
case "f_rest_10":
return PLYValue.SH_10;
case "f_rest_11":
return PLYValue.SH_11;
case "f_rest_12":
return PLYValue.SH_12;
case "f_rest_13":
return PLYValue.SH_13;
case "f_rest_14":
return PLYValue.SH_14;
case "f_rest_15":
return PLYValue.SH_15;
case "f_rest_16":
return PLYValue.SH_16;
case "f_rest_17":
return PLYValue.SH_17;
case "f_rest_18":
return PLYValue.SH_18;
case "f_rest_19":
return PLYValue.SH_19;
case "f_rest_20":
return PLYValue.SH_20;
case "f_rest_21":
return PLYValue.SH_21;
case "f_rest_22":
return PLYValue.SH_22;
case "f_rest_23":
return PLYValue.SH_23;
case "f_rest_24":
return PLYValue.SH_24;
case "f_rest_25":
return PLYValue.SH_25;
case "f_rest_26":
return PLYValue.SH_26;
case "f_rest_27":
return PLYValue.SH_27;
case "f_rest_28":
return PLYValue.SH_28;
case "f_rest_29":
return PLYValue.SH_29;
case "f_rest_30":
return PLYValue.SH_30;
case "f_rest_31":
return PLYValue.SH_31;
case "f_rest_32":
return PLYValue.SH_32;
case "f_rest_33":
return PLYValue.SH_33;
case "f_rest_34":
return PLYValue.SH_34;
case "f_rest_35":
return PLYValue.SH_35;
case "f_rest_36":
return PLYValue.SH_36;
case "f_rest_37":
return PLYValue.SH_37;
case "f_rest_38":
return PLYValue.SH_38;
case "f_rest_39":
return PLYValue.SH_39;
case "f_rest_40":
return PLYValue.SH_40;
case "f_rest_41":
return PLYValue.SH_41;
case "f_rest_42":
return PLYValue.SH_42;
case "f_rest_43":
return PLYValue.SH_43;
case "f_rest_44":
return PLYValue.SH_44;
}

return PLYValue.UNDEFINED;
Expand Down Expand Up @@ -569,11 +723,20 @@ export class GaussianSplattingMesh extends Mesh {
const vertexProperties: PlyProperty[] = [];
const chunkProperties: PlyProperty[] = [];
const filtered = header.slice(0, headerEndIndex).split("\n");
let shDegree = 0;
for (const prop of filtered) {
if (prop.startsWith("property ")) {
const [, typeName, name] = prop.split(" ");

const value = GaussianSplattingMesh._ValueNameToEnum(name);
// SH degree 1,2 or 3 for 9, 24 or 45 values
if (value >= PLYValue.SH_44) {
shDegree = 3;
} else if (value >= PLYValue.SH_24) {
shDegree = 2;
} else if (value >= PLYValue.SH_8) {
shDegree = 1;
}
const type = GaussianSplattingMesh._TypeNameToEnum(typeName);
if (chunkMode == ElementMode.Chunk) {
chunkProperties.push({ value, type, offset: rowChunkOffset });
Expand All @@ -599,6 +762,14 @@ export class GaussianSplattingMesh extends Mesh {
const dataView = new DataView(data, headerEndIndex + headerEnd.length);
const buffer = new ArrayBuffer(GaussianSplattingMesh._RowOutputLength * vertexCount);

let shBuffer = null;
let shCoefficientCount = 0;
if (shDegree) {
const shVectorCount = (shDegree + 1) * (shDegree + 1) - 1;
shCoefficientCount = shVectorCount * 3;
shBuffer = new ArrayBuffer(shCoefficientCount * vertexCount);
}

return {
vertexCount: vertexCount,
chunkCount: chunkCount,
Expand All @@ -608,6 +779,9 @@ export class GaussianSplattingMesh extends Mesh {
chunkProperties: chunkProperties,
dataView: dataView,
buffer: buffer,
shDegree: shDegree,
shCoefficientCount: shCoefficientCount,
shBuffer: shBuffer,
};
}
private static _GetCompressedChunks(header: PLYHeader, offset: { value: number }): Array<CompressedPLYChunk> | null {
Expand Down Expand Up @@ -710,6 +884,10 @@ export class GaussianSplattingMesh extends Mesh {
const scale = new Float32Array(buffer, index * rowOutputLength + 12, 3);
const rgba = new Uint8ClampedArray(buffer, index * rowOutputLength + 24, 4);
const rot = new Uint8ClampedArray(buffer, index * rowOutputLength + 28, 4);
let sh = null;
if (header.shBuffer) {
sh = new Uint8ClampedArray(header.shBuffer, index * header.shCoefficientCount, header.shCoefficientCount);
}
const chunkIndex = index >> 8;
let r0: number = 255;
let r1: number = 0;
Expand Down Expand Up @@ -831,6 +1009,11 @@ export class GaussianSplattingMesh extends Mesh {
r3 = value;
break;
}
if (sh && property.value >= PLYValue.SH_0 && property.value <= PLYValue.SH_44) {
const clampedValue = Scalar.Clamp(value * 127.5 + 127.5, 0, 255);
const shIndex = property.value - PLYValue.SH_0;
sh[shIndex] = clampedValue;
}
}

q.set(r1, r2, r3, r0);
Expand All @@ -842,12 +1025,74 @@ export class GaussianSplattingMesh extends Mesh {
offset.value += header.rowVertexLength;
}

/**
* Converts a .ply data with SH coefficients splat
* if data array buffer is not ply, returns the original buffer
* @param data the .ply data to load
* @param useCoroutine use coroutine and yield
* @returns the loaded splat buffer and optional array of sh coefficients
*/
public static *ConvertPLYWithSHToSplat(data: ArrayBuffer, useCoroutine = false) {
const header = GaussianSplattingMesh.ParseHeader(data);
if (!header) {
return { buffer: data };
}

const offset = { value: 0 };
const compressedChunks = GaussianSplattingMesh._GetCompressedChunks(header, offset);

for (let i = 0; i < header.vertexCount; i++) {
GaussianSplattingMesh._GetSplat(header, i, compressedChunks, offset);
if (i % GaussianSplattingMesh._PlyConversionBatchSize === 0 && useCoroutine) {
yield;
}
}

let sh = null;
// make SH texture buffers
if (header.shDegree && header.shBuffer) {
const textureCount = Math.ceil(header.shCoefficientCount / 16); // 4 components can be stored per texture, 4 sh per component
let shIndexRead = 0;
const ubuf = new Uint8Array(header.shBuffer);

// sh is an array of uint8array that will be used to create sh textures
sh = [];

const splatCount = header.vertexCount;
const engine = EngineStore.LastCreatedEngine;
if (engine) {
const width = engine.getCaps().maxTextureSize;
const height = Math.ceil(splatCount / width);
// create array for the number of textures needed.
for (let textureIndex = 0; textureIndex < textureCount; textureIndex++) {
const texture = new Uint8Array(height * width * 4 * 4); // 4 components per texture, 4 sh per component
sh.push(texture);
}

for (let i = 0; i < splatCount; i++) {
for (let shIndexWrite = 0; shIndexWrite < header.shCoefficientCount; shIndexWrite++) {
const shValue = ubuf[shIndexRead++];

const textureIndex = Math.floor(shIndexWrite / 16);
const shArray = sh[textureIndex];

const byteIndexInTexture = shIndexWrite % 16; // [0..15]
const offsetPerSplat = i * 16; // 16 sh values per texture per splat.
shArray[byteIndexInTexture + offsetPerSplat] = shValue;
}
}
}
}

return { buffer: header.buffer, sh: sh };
}

/**
* Converts a .ply data array buffer to splat
* if data array buffer is not ply, returns the original buffer
* @param data the .ply data to load
* @param useCoroutine use coroutine and yield
* @returns the loaded splat buffer
* @returns the loaded splat buffer without SH coefficient, whether ply contains or not SH.
*/
public static *ConvertPLYToSplat(data: ArrayBuffer, useCoroutine = false) {
const header = GaussianSplattingMesh.ParseHeader(data);
Expand Down Expand Up @@ -878,6 +1123,15 @@ export class GaussianSplattingMesh extends Mesh {
return runCoroutineAsync(GaussianSplattingMesh.ConvertPLYToSplat(data, true), createYieldingScheduler());
}

/**
* Converts a .ply with SH data array buffer to splat
* if data array buffer is not ply, returns the original buffer
* @param data the .ply data to load
* @returns the loaded splat buffer with SH
*/
public static async ConvertPLYWithSHToSplatAsync(data: ArrayBuffer) {
return runCoroutineAsync(GaussianSplattingMesh.ConvertPLYWithSHToSplat(data, true), createYieldingScheduler());
}
/**
* Loads a .splat Gaussian Splatting array buffer asynchronously
* @param data arraybuffer containing splat file
Expand All @@ -896,8 +1150,8 @@ export class GaussianSplattingMesh extends Mesh {
*/
public loadFileAsync(url: string): Promise<void> {
return Tools.LoadFileAsync(url, true).then(async (plyBuffer) => {
GaussianSplattingMesh.ConvertPLYToSplatAsync(plyBuffer).then((splatsData) => {
this.updateDataAsync(splatsData);
(GaussianSplattingMesh.ConvertPLYWithSHToSplatAsync(plyBuffer) as any).then((splatsData: PLYConversionBuffers) => {
this.updateDataAsync(splatsData.buffer, splatsData.sh);
});
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,19 @@ vec3 computeColorFromSHDegree(vec3 dir, const vec3 sh[16])
result +=
SH_C2[0] * xy * sh[4] +
SH_C2[1] * yz * sh[5] +
SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] +
SH_C2[2] * (2.0 * zz - xx - yy) * sh[6] +
SH_C2[3] * xz * sh[7] +
SH_C2[4] * (xx - yy) * sh[8];

#if SH_DEGREE > 2
result +=
SH_C3[0] * y * (3.0f * xx - yy) * sh[9] +
SH_C3[0] * y * (3.0 * xx - yy) * sh[9] +
SH_C3[1] * xy * z * sh[10] +
SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] +
SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] +
SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] +
SH_C3[2] * y * (4.0 * zz - xx - yy) * sh[11] +
SH_C3[3] * z * (2.0 * zz - 3.0 * xx - 3.0 * yy) * sh[12] +
SH_C3[4] * x * (4.0 * zz - xx - yy) * sh[13] +
SH_C3[5] * z * (xx - yy) * sh[14] +
SH_C3[6] * x * (xx - 3.0f * yy) * sh[15];
SH_C3[6] * x * (xx - 3.0 * yy) * sh[15];
#endif
#endif
#endif
Expand Down
1 change: 1 addition & 0 deletions packages/dev/core/src/Shaders/gaussianSplatting.vertex.fx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ void main () {

#if SH_DEGREE > 0
vec3 dir = normalize(worldPos.xyz - vEyePosition.xyz);
dir.y *= -1.; // Up is inverted. This corresponds to change in _makeSplat method
vColor.xyz = computeSH(splat, splat.color.xyz, dir);
#endif

Expand Down
Loading

0 comments on commit 26d7e7b

Please sign in to comment.