diff --git a/tpu/src/main/java/tpu/CreateTpuVmWithStartupScript.java b/tpu/src/main/java/tpu/CreateTpuVmWithStartupScript.java new file mode 100644 index 00000000000..e25fa04b5fd --- /dev/null +++ b/tpu/src/main/java/tpu/CreateTpuVmWithStartupScript.java @@ -0,0 +1,106 @@ +/* +* Copyright 2024 Google LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package tpu; + +//[START tpu_vm_create_startup_script] +import com.google.api.gax.longrunning.OperationTimedPollAlgorithm; +import com.google.api.gax.retrying.RetrySettings; +import com.google.cloud.tpu.v2.CreateNodeRequest; +import com.google.cloud.tpu.v2.Node; +import com.google.cloud.tpu.v2.TpuClient; +import com.google.cloud.tpu.v2.TpuSettings; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import org.threeten.bp.Duration; + +public class CreateTpuVmWithStartupScript { + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + // Project ID or project number of the Google Cloud project you want to create a node. + String projectId = "YOUR_PROJECT_ID"; + // The zone in which to create the TPU. + // For more information about supported TPU types for specific zones, + // see https://cloud.google.com/tpu/docs/regions-zones + String zone = "europe-west4-a"; + // The name for your TPU. + String nodeName = "YOUR_TPU_NAME"; + // The accelerator type that specifies the version and size of the Cloud TPU you want to create. + // For more information about supported accelerator types for each TPU version, + // see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions. + String acceleratorType = "v2-8"; + // Software version that specifies the version of the TPU runtime to install. + // For more information, see https://cloud.google.com/tpu/docs/runtimes + String tpuSoftwareVersion = "tpu-vm-tf-2.14.1"; + + createTpuVmWithStartupScript(projectId, zone, nodeName, acceleratorType, tpuSoftwareVersion); + } + + // Create a TPU VM with a startup script. + public static Node createTpuVmWithStartupScript(String projectId, String zone, + String nodeName, String acceleratorType, String tpuSoftwareVersion) + throws IOException, ExecutionException, InterruptedException { + // With these settings the client library handles the Operation's polling mechanism + // and prevent CancellationException error + TpuSettings.Builder clientSettings = + TpuSettings.newBuilder(); + clientSettings + .createNodeOperationSettings() + .setPollingAlgorithm( + OperationTimedPollAlgorithm.create( + RetrySettings.newBuilder() + .setInitialRetryDelay(Duration.ofMillis(5000L)) + .setRetryDelayMultiplier(1.5) + .setMaxRetryDelay(Duration.ofMillis(45000L)) + .setInitialRpcTimeout(Duration.ZERO) + .setRpcTimeoutMultiplier(1.0) + .setMaxRpcTimeout(Duration.ZERO) + .setTotalTimeout(Duration.ofHours(24L)) + .build())); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. + try (TpuClient tpuClient = TpuClient.create(clientSettings.build())) { + String parent = String.format("projects/%s/locations/%s", projectId, zone); + + String startupScriptContent = "#!/bin/bash\necho \"Hello from the startup script!\""; + // Add startup script to metadata + Map metadata = new HashMap<>(); + metadata.put("startup-script", startupScriptContent); + + Node tpuVm = + Node.newBuilder() + .setName(nodeName) + .setAcceleratorType(acceleratorType) + .setRuntimeVersion(tpuSoftwareVersion) + .putAllMetadata(metadata) + .build(); + + CreateNodeRequest request = + CreateNodeRequest.newBuilder() + .setParent(parent) + .setNodeId(nodeName) + .setNode(tpuVm) + .build(); + + return tpuClient.createNodeAsync(request).get(); + } + } +} +//[END tpu_vm_create_startup_script] \ No newline at end of file diff --git a/tpu/src/test/java/tpu/CreateTpuVmWithStartupScriptIT.java b/tpu/src/test/java/tpu/CreateTpuVmWithStartupScriptIT.java new file mode 100644 index 00000000000..c4b001f9e85 --- /dev/null +++ b/tpu/src/test/java/tpu/CreateTpuVmWithStartupScriptIT.java @@ -0,0 +1,73 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tpu; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static org.junit.Assert.assertNotNull; + +import com.google.cloud.tpu.v2.Node; +import java.io.IOException; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import org.junit.Assert; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@Timeout(value = 6, unit = TimeUnit.MINUTES) +public class CreateTpuVmWithStartupScriptIT { + private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT"); + private static final String ZONE = "asia-east1-c"; + private static final String NODE_NAME = "test-tpu-with-script-" + UUID.randomUUID(); + private static final String TPU_TYPE = "v2-8"; + private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1"; + + public static void requireEnvVar(String envVarName) { + assertWithMessage(String.format("Missing environment variable '%s' ", envVarName)) + .that(System.getenv(envVarName)).isNotEmpty(); + } + + @BeforeAll + public static void setUp() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("GOOGLE_CLOUD_PROJECT"); + } + + @AfterAll + public static void cleanup() throws Exception { + DeleteTpuVm.deleteTpuVm(PROJECT_ID, ZONE, NODE_NAME); + } + + @Test + public void testCreateTpuVmWithStartupScript() + throws IOException, ExecutionException, InterruptedException { + Node node = CreateTpuVmWithStartupScript.createTpuVmWithStartupScript( + PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION); + + assertNotNull(node); + assertThat(node.getName().equals(NODE_NAME)); + Assert.assertTrue(node.containsMetadata("startup-script")); + Assert.assertTrue(node.getMetadataMap().containsValue("#!/bin/bash\n" + + "echo \"Hello from the startup script!\"")); + } +} \ No newline at end of file diff --git a/tpu/src/test/java/tpu/QueuedResourceIT.java b/tpu/src/test/java/tpu/QueuedResourceIT.java index a7dbba51ff4..f52a1dce640 100644 --- a/tpu/src/test/java/tpu/QueuedResourceIT.java +++ b/tpu/src/test/java/tpu/QueuedResourceIT.java @@ -32,6 +32,7 @@ @RunWith(JUnit4.class) @Timeout(value = 6, unit = TimeUnit.MINUTES) public class QueuedResourceIT { + private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT"); private static final String ZONE = "europe-west4-a"; private static final String NODE_NAME = "test-tpu-queued-resource-network-" + UUID.randomUUID(); diff --git a/tpu/src/test/java/tpu/TpuVmIT.java b/tpu/src/test/java/tpu/TpuVmIT.java index 761c1b1c5bd..79de922f4dd 100644 --- a/tpu/src/test/java/tpu/TpuVmIT.java +++ b/tpu/src/test/java/tpu/TpuVmIT.java @@ -73,7 +73,6 @@ public static void cleanup() throws Exception { @Test @Order(1) public void testCreateTpuVm() throws IOException, ExecutionException, InterruptedException { - Node node = CreateTpuVm.createTpuVm( PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);