Databricks SQL Onboarding
Steps
%pip install databricks-sdk --upgrade%restart_pythonMULTI_WORKSPACE = True
import json
import time
from datetime import datetime
from databricks.sdk import AccountClient, WorkspaceClient
from databricks.sdk.errors import BadRequest
from databricks.sdk.service import catalog
from databricks.sdk.service.iam import (
ComplexValue,
Patch,
PatchOp,
PatchSchema,
WorkspacePermission,
)
from pydantic import BaseModel, field_validator
from pyspark.sql import SparkSession
class DatabricksOAuthToken(BaseModel):
    """OAuth secret issued for a service principal, with its client id and validity window."""

    id: str
    oauth_secret: str
    created_at: datetime
    expires_at: datetime
    client_id: str

    @field_validator("oauth_secret", "client_id", mode="before")
    @classmethod
    def strip_whitespace(cls, v: str) -> str:
        """Trim surrounding whitespace from string inputs; other values pass through untouched."""
        if not isinstance(v, str):
            return v
        return v.strip()
class DatabricksCredentials(BaseModel):
    """Connection details for one Databricks workspace, as handed to Espresso AI."""

    oauth_token: DatabricksOAuthToken
    workspace_url: str
    service_principal_name: str = "espresso-ai-optimizer"
    service_principal_id: str
    warehouse_id: str | None = None
    warehouse_name: str | None = None
    workspace_id: str | None = None
    workspace_name: str | None = None

    def to_json(self) -> str:
        """Serialize these credentials to a JSON string (JSON-mode dump, so datetimes become text)."""
        payload = self.model_dump(mode="json")
        return json.dumps(payload)
def _is_duplicate_error(e):
msg = str(e).lower()
return "already exists" in msg or "duplicate" in msg
def _patch_sp_entitlements(client, sp):
    """Add the SQL-access and cluster-create entitlements to service principal *sp*.

    Issues a single SCIM PATCH ``add`` on the ``entitlements`` path through
    ``client.service_principals.patch``.
    """
    wanted = [
        {"value": entitlement}
        for entitlement in ("databricks-sql-access", "allow-cluster-create")
    ]
    add_entitlements = Patch(op=PatchOp.ADD, path="entitlements", value=wanted)
    client.service_principals.patch(
        id=sp.id,
        operations=[add_entitlements],
        schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP],
    )
def get_or_create_service_principal(client, name="espresso-ai-optimizer"):
    """Return the workspace service principal with display name *name*, creating it if absent.

    If a principal with that name already exists, its entitlements are
    re-applied and the first match is returned. Otherwise a new, active
    principal is created with cluster-create and SQL-access entitlements.
    """
    matches = list(client.service_principals.list(filter=f"displayName eq '{name}'"))
    if not matches:
        return client.service_principals.create(
            display_name=name,
            active=True,
            entitlements=[
                ComplexValue(value="allow-cluster-create"),
                ComplexValue(value="databricks-sql-access"),
            ],
        )
    existing = matches[0]
    _patch_sp_entitlements(client, existing)
    return existing
def find_service_principal(
    client, name="espresso-ai-optimizer", timeout_secs=60, poll_interval_secs=5
):
    """Poll the workspace until a service principal named *name* is listed.

    A workspace assignment made through the account API may not be visible
    in the workspace immediately, so the lookup is retried until
    *timeout_secs* elapse. At least one lookup is always made, even with
    ``timeout_secs=0``, and the final wait never sleeps past the deadline.

    Args:
        client: Workspace client whose ``service_principals.list`` is queried.
        name: Display name to search for.
        timeout_secs: Total polling budget in seconds.
        poll_interval_secs: Delay between lookups in seconds.

    Returns:
        The first matching service principal.

    Raises:
        RuntimeError: If no match appears before the deadline.
    """
    deadline = time.monotonic() + timeout_secs
    query = f"displayName eq '{name}'"
    while True:
        # Only the first match is needed; avoid paging through the whole result.
        match = next(iter(client.service_principals.list(filter=query)), None)
        if match is not None:
            return match
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(poll_interval_secs, remaining))
    raise RuntimeError(
        f"Service principal {name!r} not visible after {timeout_secs}s. "
        "Workspace assignment may not have propagated."
    )
def create_oauth_token(client, service_principal):
    """Issue a new OAuth secret for *service_principal* with a two-year lifetime.

    Returns:
        A ``DatabricksOAuthToken`` built from the secret-creation response,
        with the principal's application id as the OAuth client id.
    """
    two_years_secs = 2 * 365 * 24 * 60 * 60
    issued = client.service_principal_secrets_proxy.create(
        service_principal_id=service_principal.id,
        lifetime=f"{two_years_secs}s",
    )
    token_fields = {
        "id": issued.id,
        "oauth_secret": issued.secret,
        "created_at": issued.create_time,
        "expires_at": issued.expire_time,
        "client_id": service_principal.application_id,
    }
    return DatabricksOAuthToken(**token_fields)
def make_service_principal_workspace_admin(client, service_principal):
    """Add *service_principal* as a member of the workspace ``admins`` group.

    Adding a principal that is already a member is treated as success.

    Raises:
        RuntimeError: If the workspace has no group named ``admins``.
        Exception: Any error from the membership PATCH other than an
            "already exists"/duplicate error is re-raised unchanged.
    """
    admin_group = next(
        iter(client.groups.list(filter="displayName eq 'admins'")), None
    )
    if admin_group is None:
        # Fail with a clear message instead of letting next() raise a bare StopIteration.
        raise RuntimeError("No 'admins' group found in this workspace.")
    try:
        client.groups.patch(
            id=admin_group.id,
            operations=[
                Patch(
                    op=PatchOp.ADD,
                    value={"members": [{"value": service_principal.id}]},
                )
            ],
            schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP],
        )
    except Exception as e:
        # Re-running the script hits an existing membership; that is fine.
        if not _is_duplicate_error(e):
            raise
def get_or_create_warehouse(client, name="ESPRESSO_AI_WAREHOUSE"):
    """Return the id of the SQL warehouse called *name*, creating it if none exists.

    Args:
        client: Workspace client used for ``warehouses.list`` / ``create_and_wait``.
        name: Warehouse name to look up or create. It defaults to the name the
            rest of this script reports.

    Returns:
        The warehouse id. A newly created warehouse is X-Small, serverless,
        and fixed at one cluster with a 1-minute auto-stop. Creation goes
        through the SDK's ``create_and_wait``.
    """
    for warehouse in client.warehouses.list():
        if warehouse.name == name:
            return warehouse.id
    return client.warehouses.create_and_wait(
        name=name,
        cluster_size="X-Small",
        auto_stop_mins=1,
        enable_serverless_compute=True,
        min_num_clusters=1,
        max_num_clusters=1,
    ).id
def allow_service_principal_to_read_system_logs(client, service_principal):
    """Grant *service_principal* read access to every schema of the ``system`` catalog.

    The grants are USE CATALOG on ``system``, plus USE SCHEMA and SELECT on
    each ``system.<schema>``. The principal is identified by its application
    id. Every grant is attempted even if earlier ones fail. Failures are
    collected and reported together.

    Raises:
        RuntimeError: If any grant failed. The message lists each failed
            securable and privilege with the original error text. Hints about
            missing account/metastore admin rights are printed first.
    """
    failures = []  # (securable full name, privilege, original error text)

    def grant(asset_name, asset_type, privilege):
        try:
            client.grants.update(
                full_name=asset_name,
                securable_type=asset_type.value,
                changes=[
                    catalog.PermissionsChange(
                        add=[privilege], principal=service_principal.application_id
                    )
                ],
            )
        except Exception as e:
            failures.append((asset_name, privilege, str(e)))

    grant("system", catalog.SecurableType.CATALOG, catalog.Privilege.USE_CATALOG)
    for schema in client.schemas.list(catalog_name="system"):
        schema_full_name = f"system.{schema.name}"
        grant(schema_full_name, catalog.SecurableType.SCHEMA, catalog.Privilege.USE_SCHEMA)
        grant(schema_full_name, catalog.SecurableType.SCHEMA, catalog.Privilege.SELECT)

    if not failures:
        return

    # Match hints case-insensitively, but keep the original text for the report.
    lowered = [message.lower() for _, _, message in failures]
    if any("account admin" in m for m in lowered):
        print("\n⚠️ ACCOUNT ADMIN required: Ask an account admin to grant you access.")
    if any("manage on catalog" in m or "metastore" in m for m in lowered):
        print("⚠️ METASTORE ADMIN required: Visit https://accounts.cloud.databricks.com/data")
    details = [
        f"{asset} ({getattr(privilege, 'value', privilege)}): {message}"
        for asset, privilege, message in failures
    ]
    raise RuntimeError(
        "Failed to grant system table access (resolve the above and re-run):\n - "
        + "\n - ".join(details)
    )
def wait_for_account_admin(account_id, client_id, secret):
    """Poll for up to 5 minutes until the given OAuth credentials have account access.

    Each attempt builds an ``AccountClient`` for
    ``https://accounts.cloud.databricks.com`` and lists the account's
    workspaces. Failures are retried every 5 seconds. They are expected until
    someone grants the account-admin role.

    Returns:
        ``(account_client, workspaces)``, where ``workspaces`` is a list of
        ``(workspace_id, workspace_name, deployment_name)`` tuples.

    Raises:
        TimeoutError: If no attempt succeeds within 5 minutes. The last error
            seen is appended to the message and chained as the cause.
    """
    deadline = time.monotonic() + 300  # 5 minutes
    last_error = None
    while time.monotonic() < deadline:
        try:
            ac = AccountClient(
                host="https://accounts.cloud.databricks.com",
                account_id=account_id,
                client_id=client_id,
                client_secret=secret,
            )
            workspaces = [
                (ws.workspace_id, ws.workspace_name, ws.deployment_name)
                for ws in ac.workspaces.list()
            ]
        except Exception as e:
            # Remember why it failed so a timeout is diagnosable (e.g. a bad
            # account id or secret would otherwise look like a missing grant).
            last_error = e
            time.sleep(5)
            continue
        print("Account admin access confirmed.")
        return ac, workspaces
    message = (
        "Account admin not granted within 5 minutes. Grant the role via the link above and re-run."
    )
    if last_error is not None:
        message += f" Last error: {last_error}"
    raise TimeoutError(message) from last_error
def filter_to_current_metastore(account_client, workspace_client, workspaces):
    """Keep only workspaces assigned to the metastore that *workspace_client* uses.

    Each entry of *workspaces* is a tuple whose first element is the
    workspace id. Surviving entries keep their original order.
    """
    current_metastore = workspace_client.metastores.current().metastore_id
    assigned_ids = {
        assigned for assigned in account_client.metastore_assignments.list(current_metastore)
    }
    return [entry for entry in workspaces if entry[0] in assigned_ids]
def check_workspaces_federated(account_client, workspaces):
    """Raise PermissionError naming every workspace without identity federation.

    Each workspace is probed with ``workspace_assignment.list``. A
    ``BadRequest`` saying the permission-assignment APIs are unavailable marks
    the workspace as not federated. Any other ``BadRequest`` propagates
    immediately.
    """
    missing = []
    for workspace_id, workspace_name, _deployment in workspaces:
        try:
            account_client.workspace_assignment.list(workspace_id=workspace_id)
        except BadRequest as err:
            if "permission assignment apis are not available" in str(err).lower():
                missing.append((workspace_id, workspace_name))
                continue
            raise
    if not missing:
        return
    bullet_list = "\n - ".join(f"{name} ({wid})" for wid, name in missing)
    raise PermissionError(
        "IDENTITY FEDERATION required on the following workspaces:\n"
        f" - {bullet_list}\n\n"
        "Enable identity federation via the account console:\n"
        " Workspaces → <workspace> → Configuration tab → "
        "'Identity federation' must show 'Enabled'.\n"
    )
def assign_sp_to_workspaces(account_client, sp_id, workspaces):
    """Give service principal *sp_id* ADMIN permission on each listed workspace.

    Each workspace's assignment is retried every 5 seconds for up to 60
    seconds. An "already exists"/duplicate error counts as done. Once a
    workspace's 60-second window has passed, the last error is re-raised.
    """
    for ws_id, ws_name, _ in workspaces:
        give_up_at = time.monotonic() + 60
        done = False
        while not done:
            try:
                account_client.workspace_assignment.update(
                    workspace_id=ws_id,
                    principal_id=sp_id,
                    permissions=[WorkspacePermission.ADMIN],
                )
                print(f" Assigned as admin on {ws_name} ({ws_id})")
                done = True
            except Exception as err:
                if _is_duplicate_error(err):
                    done = True
                elif time.monotonic() >= give_up_at:
                    raise
                else:
                    time.sleep(5)
def setup_workspace(oauth_token, workspace_url, workspace_name, workspace_id, sp_name):
    """Prepare another workspace by acting as the service principal itself.

    Connects to *workspace_url* with the principal's OAuth credentials and
    waits until the principal is visible there. It then re-applies the
    principal's entitlements, adds it to the ``admins`` group, grants
    system-table access, and gets or creates the warehouse.
    ``workspace_name`` and ``workspace_id`` are accepted but not used here.

    Returns:
        The warehouse id in that workspace.
    """
    sp_client = WorkspaceClient(
        host=workspace_url,
        client_id=oauth_token.client_id,
        client_secret=oauth_token.oauth_secret,
    )
    principal = find_service_principal(sp_client, sp_name)
    for step in (
        _patch_sp_entitlements,
        make_service_principal_workspace_admin,
        allow_service_principal_to_read_system_logs,
    ):
        step(sp_client, principal)
    return get_or_create_warehouse(sp_client)
def _delete_temp_service_principal(workspace_client, account_client, sp_id):
    """Best-effort removal of the temporary service principal; never raises.

    Tries the account-level API first when *account_client* is available,
    then falls back to the workspace SCIM API. Failures are printed rather
    than raised so that a cleanup problem cannot mask the error that
    triggered an enclosing ``finally``.
    """
    # NOTE(review): on identity-federated workspaces a workspace-level delete
    # may only remove the principal from this workspace and leave its account
    # admin role in place. That is why the account-level delete is tried first.
    attempts = []
    if account_client is not None:
        attempts.append(("account", account_client.service_principals.delete))
    attempts.append(("workspace", workspace_client.service_principals.delete))
    for scope, delete in attempts:
        try:
            delete(id=sp_id)
        except Exception as e:
            print(f"⚠️ Could not delete temporary service principal via {scope} API: {e}")
            continue
        print("Temporary service principal deleted.")
        return
    print(
        f"⚠️ Delete the temporary service principal (id {sp_id}) manually; "
        "it may still hold the account admin role."
    )


if __name__ == "__main__":
    spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
    workspace_url = f"https://{spark.conf.get('spark.databricks.workspaceUrl') or ''}"
    client = WorkspaceClient()

    # Long-lived Espresso AI service principal, secret, and warehouse in this workspace.
    service_principal = get_or_create_service_principal(client)
    oauth_token = create_oauth_token(client, service_principal)
    warehouse_id = get_or_create_warehouse(client)
    make_service_principal_workspace_admin(client, service_principal)
    allow_service_principal_to_read_system_logs(client, service_principal)

    credentials_by_workspace: dict[str, DatabricksCredentials] = {}
    if not MULTI_WORKSPACE:
        workspace_id = str(client.get_workspace_id())
        rows = spark.sql(
            f"SELECT workspace_name FROM system.access.workspaces_latest "
            f"WHERE workspace_id = {workspace_id}"
        ).collect()
        # Fall back to the id instead of crashing with IndexError when the
        # system table has no row for this workspace.
        workspace_name = rows[0][0] if rows else workspace_id
        credentials_by_workspace[workspace_name] = DatabricksCredentials(
            oauth_token=oauth_token,
            workspace_url=workspace_url,
            workspace_id=workspace_id,
            workspace_name=workspace_name,
            service_principal_name=service_principal.display_name,
            service_principal_id=service_principal.id,
            warehouse_id=warehouse_id,
            warehouse_name="ESPRESSO_AI_WAREHOUSE",
        )
    else:
        account_id = spark.sql(
            "SELECT account_id FROM system.billing.usage LIMIT 1"
        ).collect()[0][0]
        # A human grants this helper principal account admin. The script uses it
        # to assign the main principal to the other workspaces, then removes it.
        temp_sp = get_or_create_service_principal(client, "espresso-ai-temp")
        account_client = None
        try:
            temp_secret = client.service_principal_secrets_proxy.create(
                service_principal_id=temp_sp.id, lifetime=f"{60 * 60}s"
            ).secret
            roles_url = (
                f"https://accounts.cloud.databricks.com/user-management/"
                f"serviceprincipals/{temp_sp.id}/roles?account_id={account_id}"
            )
            print(
                f"\n📋 To continue, grant the temporary service principal account admin access."
                f"\n Open this link and add the 'Account admin' role:\n\n {roles_url}\n"
            )
            account_client, workspaces = wait_for_account_admin(
                account_id, temp_sp.application_id, temp_secret
            )
            workspaces = filter_to_current_metastore(account_client, client, workspaces)
            check_workspaces_federated(account_client, workspaces)
            assign_sp_to_workspaces(account_client, int(service_principal.id), workspaces)
        finally:
            # Runs on failure too (timeout, federation check, assignment), so a
            # failed run never leaves an account-admin identity behind.
            _delete_temp_service_principal(client, account_client, temp_sp.id)

        current_ws_id = str(client.get_workspace_id())
        for ws_id, ws_name, deployment_name in workspaces:
            ws_id_str = str(ws_id)
            workspace_url_for_ws = f"https://{deployment_name}.cloud.databricks.com"
            if ws_id_str == current_ws_id:
                # This workspace was already set up above with the notebook's own client.
                ws_warehouse_id = warehouse_id
            else:
                ws_warehouse_id = setup_workspace(
                    oauth_token=oauth_token,
                    workspace_url=workspace_url_for_ws,
                    workspace_name=ws_name,
                    workspace_id=ws_id_str,
                    sp_name=service_principal.display_name,
                )
            credentials_by_workspace[ws_name] = DatabricksCredentials(
                oauth_token=oauth_token,
                workspace_url=workspace_url_for_ws,
                workspace_id=ws_id_str,
                workspace_name=ws_name,
                service_principal_name=service_principal.display_name,
                service_principal_id=service_principal.id,
                warehouse_id=ws_warehouse_id,
                warehouse_name="ESPRESSO_AI_WAREHOUSE" if ws_warehouse_id else None,
            )

    print("\n🎉 Setup complete! Here are the Databricks credentials to send to Espresso AI:")
    print("=" * 50)
    print(
        json.dumps(
            {
                ws: cred.model_dump(mode="json")
                for ws, cred in credentials_by_workspace.items()
            },
            indent=2,
        )
    )
    print("=" * 50)
What the Script Does:
Troubleshooting
Questions?
Last updated