Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2604,12 +2604,12 @@ def submit(
or restart_job_on_worker_restart
or disable_retries
or scheduling_strategy
or max_wait_duration
or max_wait_duration is not None # 0 is a valid value
):
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
max_wait_duration = (
duration_pb2.Duration(seconds=max_wait_duration)
if max_wait_duration
if max_wait_duration is not None
else None
)
self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
Expand Down Expand Up @@ -3133,13 +3133,13 @@ def _run(
timeout
or restart_job_on_worker_restart
or disable_retries
or max_wait_duration
or max_wait_duration is not None # 0 is a valid value
or scheduling_strategy
):
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
max_wait_duration = (
duration_pb2.Duration(seconds=max_wait_duration)
if max_wait_duration
if max_wait_duration is not None
else None
)
self._gca_resource.trial_job_spec.scheduling = (
Expand Down
96 changes: 96 additions & 0 deletions tests/unit/aiplatform/test_custom_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#

import datetime
import pytest
import logging

Expand Down Expand Up @@ -730,6 +731,101 @@ def test_submit_custom_job(self, create_custom_job_mock, get_custom_job_mock):
)
assert job.network == _TEST_NETWORK

def test_submit_custom_job_with_zero_max_wait_duration(
    self, create_custom_job_mock, get_custom_job_mock
):
    """Verify that ``submit()`` forwards ``max_wait_duration=0`` to the API.

    Regression test for the falsy-zero bug: 0 is a valid value for
    ``max_wait_duration`` and must appear in the job's scheduling config as
    an explicit zero-length duration instead of being silently dropped.
    """

    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    job.submit(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
        max_wait_duration=0,  # falsy but valid — must not be ignored
    )

    job.wait_for_resource_creation()

    assert job.resource_name == _TEST_CUSTOM_JOB_NAME

    job.wait()

    expected_custom_job = _get_custom_job_proto()
    # NOTE(review): assigning a timedelta here relies on proto-plus coercing
    # it into a google.protobuf.Duration field — confirm against proto-plus
    # field-assignment semantics.
    expected_custom_job.job_spec.scheduling.max_wait_duration = datetime.timedelta(
        seconds=0
    )

    # The create RPC must have been issued exactly once with the zero
    # duration present in the expected proto.
    create_custom_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        custom_job=expected_custom_job,
        timeout=None,
    )
    assert (
        job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
    )

def test_submit_custom_job_with_default_max_wait_duration(
    self, create_custom_job_mock, get_custom_job_mock
):
    """Verify that omitting ``max_wait_duration`` leaves the field unset.

    Counterpart to the zero-value test: when the caller does not pass
    ``max_wait_duration``, the scheduling config sent to the API must not
    contain the field at all.
    """

    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    # max_wait_duration is deliberately NOT passed here.
    job.submit(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
    )

    job.wait_for_resource_creation()

    assert job.resource_name == _TEST_CUSTOM_JOB_NAME

    job.wait()

    expected_custom_job = _get_custom_job_proto()
    # NOTE(review): assigning None presumably clears the message field
    # (proto-plus semantics) so the expected proto matches an unset field —
    # confirm against proto-plus docs.
    expected_custom_job.job_spec.scheduling.max_wait_duration = None

    create_custom_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        custom_job=expected_custom_job,
        timeout=None,
    )

    # Field-presence check: the scheduling message must not carry the field.
    assert "max_wait_duration" not in expected_custom_job.job_spec.scheduling
    assert (
        job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
    )

@pytest.mark.usefixtures(
"get_experiment_run_mock", "get_tensorboard_run_artifact_not_found_mock"
)
Expand Down
149 changes: 148 additions & 1 deletion tests/unit/aiplatform/test_hyperparameter_tuning_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# limitations under the License.
#

import copy
import datetime
import pytest

import copy
from importlib import reload
from unittest import mock
from unittest.mock import patch
Expand Down Expand Up @@ -523,6 +524,152 @@ def test_create_hyperparameter_tuning_job(
assert job.network == _TEST_NETWORK
assert job.trials == []

def test_create_hyperparameter_tuning_job_with_zero_max_wait_duration(
    self,
    create_hyperparameter_tuning_job_mock,
    get_hyperparameter_tuning_job_mock,
):
    """Verify that ``run()`` forwards ``max_wait_duration=0`` for HP tuning.

    Regression test for the falsy-zero bug on the hyperparameter tuning
    path: 0 is a valid value and must appear in the trial job spec's
    scheduling config as an explicit zero-length duration.
    """

    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    custom_job = aiplatform.CustomJob(
        display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
        worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
        base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
    )

    job = aiplatform.HyperparameterTuningJob(
        display_name=_TEST_DISPLAY_NAME,
        custom_job=custom_job,
        metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE},
        parameter_spec={
            "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"),
            "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"),
            "activation": hpt.CategoricalParameterSpec(
                values=["relu", "sigmoid", "elu", "selu", "tanh"]
            ),
            "batch_size": hpt.DiscreteParameterSpec(
                values=[4, 8, 16, 32, 64],
                scale="linear",
                conditional_parameter_spec={
                    "decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
                    "learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
                },
            ),
        },
        parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
        max_trial_count=_TEST_MAX_TRIAL_COUNT,
        max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT,
        search_algorithm=_TEST_SEARCH_ALGORITHM,
        measurement_selection=_TEST_MEASUREMENT_SELECTION,
        labels=_TEST_LABELS,
    )

    job.run(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        sync=True,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
        max_wait_duration=0,  # falsy but valid — must not be ignored
    )

    job.wait()

    expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto()
    # NOTE(review): assigning a timedelta relies on proto-plus coercing it
    # into a google.protobuf.Duration field — confirm against proto-plus
    # field-assignment semantics.
    expected_hyperparameter_tuning_job.trial_job_spec.scheduling.max_wait_duration = datetime.timedelta(
        seconds=0
    )

    create_hyperparameter_tuning_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        hyperparameter_tuning_job=expected_hyperparameter_tuning_job,
        timeout=None,
    )
    assert job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED

def test_create_hyperparameter_tuning_job_with_default_max_wait_duration(
    self,
    create_hyperparameter_tuning_job_mock,
    get_hyperparameter_tuning_job_mock,
):
    """Verify that omitting ``max_wait_duration`` leaves the field unset.

    Counterpart to the zero-value HP-tuning test: when the caller does not
    pass ``max_wait_duration``, the trial job spec's scheduling config sent
    to the API must not contain the field at all.
    """

    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    custom_job = aiplatform.CustomJob(
        display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
        worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
        base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
    )

    job = aiplatform.HyperparameterTuningJob(
        display_name=_TEST_DISPLAY_NAME,
        custom_job=custom_job,
        metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE},
        parameter_spec={
            "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"),
            "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"),
            "activation": hpt.CategoricalParameterSpec(
                values=["relu", "sigmoid", "elu", "selu", "tanh"]
            ),
            "batch_size": hpt.DiscreteParameterSpec(
                values=[4, 8, 16, 32, 64],
                scale="linear",
                conditional_parameter_spec={
                    "decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
                    "learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
                },
            ),
        },
        parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
        max_trial_count=_TEST_MAX_TRIAL_COUNT,
        max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT,
        search_algorithm=_TEST_SEARCH_ALGORITHM,
        measurement_selection=_TEST_MEASUREMENT_SELECTION,
        labels=_TEST_LABELS,
    )

    # max_wait_duration is deliberately NOT passed here.
    job.run(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        sync=True,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
    )

    job.wait()

    expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto()
    # NOTE(review): assigning None presumably clears the message field
    # (proto-plus semantics) so the expected proto matches an unset field —
    # confirm against proto-plus docs.
    expected_hyperparameter_tuning_job.trial_job_spec.scheduling.max_wait_duration = (
        None
    )

    create_hyperparameter_tuning_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        hyperparameter_tuning_job=expected_hyperparameter_tuning_job,
        timeout=None,
    )

    # Field-presence check: the scheduling message must not carry the field.
    assert (
        "max_wait_duration"
        not in expected_hyperparameter_tuning_job.trial_job_spec.scheduling
    )
    assert job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED

@pytest.mark.parametrize("sync", [True, False])
def test_create_hyperparameter_tuning_job_with_timeout(
self,
Expand Down
Loading