Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 62 additions & 20 deletions vertexai/_genai/sandboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@ def _CreateAgentEngineSandboxConfig_to_vertex(
if getv(from_object, ["ttl"]) is not None:
setv(parent_object, ["ttl"], getv(from_object, ["ttl"]))

if getv(from_object, ["sandbox_environment_template"]) is not None:
setv(
parent_object,
["sandboxEnvironmentTemplate"],
getv(from_object, ["sandbox_environment_template"]),
)

if getv(from_object, ["sandbox_environment_snapshot"]) is not None:
setv(
parent_object,
["sandboxEnvironmentSnapshot"],
getv(from_object, ["sandbox_environment_snapshot"]),
)

if getv(from_object, ["owner"]) is not None:
setv(parent_object, ["owner"], getv(from_object, ["owner"]))

return to_object


Expand Down Expand Up @@ -820,7 +837,7 @@ def delete(
def generate_access_token(
self,
service_account_email: str,
sandbox_id: str,
sandbox_hostname: str,
port: str = "8080",
timeout: int = 3600,
) -> str:
Expand All @@ -829,8 +846,8 @@ def generate_access_token(
Args:
service_account_email (str):
Required. The email of the service account to use for signing.
sandbox_id (str):
Required. The resource name of the sandbox to generate a token for.
sandbox_hostname (str):
Required. The hostname of the sandbox to generate a token for.
port (str):
Optional. The port to use for the token. Defaults to "8080".
timeout (int):
Expand All @@ -841,13 +858,14 @@ def generate_access_token(
"""
client = iam_credentials_v1.IAMCredentialsClient()
name = f"projects/-/serviceAccounts/{service_account_email}"
custom_claims = {"port": port, "sandbox_id": sandbox_id}
custom_claims = {"hostname": sandbox_hostname, "port": port}
payload = {
"iat": int(time.time()),
"exp": int(time.time()) + timeout,
"iss": service_account_email,
"sub": service_account_email,
"nonce": secrets.randbelow(1000000000) + 1,
"aud": "vmaas-proxy-api", # default audience for sandbox proxy
"aud": "https://aiplatform.googleapis.com/", # default audience for sandbox proxy
**custom_claims,
}
request = iam_credentials_v1.SignJwtRequest(
Expand All @@ -862,7 +880,9 @@ def send_command(
*,
http_method: str,
access_token: str,
routing_token: str,
sandbox_environment: types.SandboxEnvironment,
port: str = "8080",
path: Optional[str] = None,
query_params: Optional[dict[str, object]] = None,
headers: Optional[dict[str, str]] = None,
Expand All @@ -875,8 +895,12 @@ def send_command(
Required. The HTTP method to use for the command.
access_token (str):
Required. The access token to use for authorization.
routing_token (str):
Required. The routing token to use for authorization. This can be found in the sandbox environment's connection_info.
sandbox_environment (types.SandboxEnvironment):
Required. The sandbox environment to send the command to.
port (str):
Optional. The port to use for the token. Defaults to "8080". This should be one of the ports specified during template creation.
path (str):
Optional. The path to send the command to.
query_params (dict[str, object]):
Expand Down Expand Up @@ -905,6 +929,8 @@ def send_command(
if query_params:
path = f"{path}?{urlencode(query_params)}"
headers["Authorization"] = f"Bearer {access_token}"
headers["X-Sandbox-Routing-Token"] = routing_token
headers["X-Sandbox-Port"] = port
endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path
http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint)
http_client = genai.Client(vertexai=True, http_options=http_options)
Expand All @@ -920,6 +946,8 @@ def generate_browser_ws_headers(
self,
sandbox_environment: types.SandboxEnvironment,
service_account_email: str,
routing_token: str,
port: str = "8080",
timeout: int = 3600,
) -> tuple[str, dict[str, str]]:
"""Generates the websocket upgrade headers for the browser.
Expand All @@ -929,47 +957,61 @@ def generate_browser_ws_headers(
Required. The sandbox environment to generate websocket headers for.
service_account_email (str):
Required. The email of the service account to use for signing.
routing_token (str):
Required. The routing token to use for authorization. This can be
found in the sandbox environment's connection_info.
port (str):
Optional. The port to use for the token. Defaults to "8080". This
should be one of the ports specified during template creation.
timeout (int):
Optional. The timeout in seconds for the token. Defaults to 3600.

Returns:
tuple[str, dict[str, str]]: A tuple containing the websocket URL and
the headers for websocket upgrade.
"""
sandbox_id = sandbox_environment.name
if not sandbox_environment.connection_info:
raise ValueError("Connection info is not available.")

ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
connection_info = sandbox_environment.connection_info
if connection_info.load_balancer_hostname:
ws_base_url = "wss://" + connection_info.load_balancer_hostname
elif connection_info.load_balancer_ip:
ws_base_url = "ws://" + connection_info.load_balancer_ip
else:
raise ValueError("Load balancer hostname or ip is not available.")

# port 8080 is the default port for http endpoint.
http_access_token = self.generate_access_token(
service_account_email, sandbox_id, "8080", timeout
service_account_email, connection_info.load_balancer_hostname, port, timeout
)
response = self.send_command(
http_method="GET",
access_token=http_access_token,
routing_token=routing_token,
sandbox_environment=sandbox_environment,
port=port,
path="/cdp_ws_endpoint",
)
if not response:
raise ValueError("Failed to get the websocket endpoint.")
body_dict = json.loads(response.body)
ws_path = body_dict["endpoint"]

ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
if sandbox_environment and sandbox_environment.connection_info:
connection_info = sandbox_environment.connection_info
if connection_info.load_balancer_hostname:
ws_url = "wss://" + connection_info.load_balancer_hostname
elif connection_info.load_balancer_ip:
ws_url = "ws://" + connection_info.load_balancer_ip
else:
raise ValueError("Load balancer hostname or ip is not available.")
ws_url = ws_url + "/" + ws_path
ws_url = ws_base_url + "/" + ws_path

# port 9222 is the default port for the browser websocket endpoint.
ws_access_token = self.generate_access_token(
service_account_email, sandbox_id, "9222", timeout
service_account_email,
connection_info.load_balancer_hostname,
"9222",
timeout,
)

headers = {}
headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}"
headers["Sec-WebSocket-Protocol"] = (
f"v1.stream, {ws_access_token}, {routing_token}, {port}"
)
return ws_url, headers


Expand Down
25 changes: 25 additions & 0 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11713,6 +11713,20 @@ class CreateAgentEngineSandboxConfig(_common.BaseModel):
default=None,
description="""The TTL for this resource. The expiration time is computed: now + TTL.""",
)
sandbox_environment_template: Optional[str] = Field(
default=None,
description="""The name of the sandbox environment template to create the sandbox from. The sandbox environment template should be in the format:
projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}""",
)
sandbox_environment_snapshot: Optional[str] = Field(
default=None,
description="""The name of the sandbox environment snapshot to restore the sandbox from. The sandbox environment snapshot should be in the format:
projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}""",
)
owner: Optional[str] = Field(
default=None,
description="""Owner information for this sandbox environment. A sandbox can only be restored from a snapshot belonging to the same owner.""",
)


class CreateAgentEngineSandboxConfigDict(TypedDict, total=False):
Expand All @@ -11733,6 +11747,17 @@ class CreateAgentEngineSandboxConfigDict(TypedDict, total=False):
ttl: Optional[str]
"""The TTL for this resource. The expiration time is computed: now + TTL."""

sandbox_environment_template: Optional[str]
"""The name of the sandbox environment template to create the sandbox from. The sandbox environment template should be in the format:
projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}"""

sandbox_environment_snapshot: Optional[str]
"""The name of the sandbox environment snapshot to restore the sandbox from. The sandbox environment snapshot should be in the format:
projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}"""

owner: Optional[str]
"""Owner information for this sandbox environment. A sandbox can only be restored from a snapshot belonging to the same owner."""


CreateAgentEngineSandboxConfigOrDict = Union[
CreateAgentEngineSandboxConfig, CreateAgentEngineSandboxConfigDict
Expand Down
Loading