Databricks SQL Onboarding
Steps
%pip install databricks-sdk --upgrade
%restart_python
import json
from datetime import datetime

from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
from databricks.sdk.service.iam import (
    AccessControlRequest,
    ComplexValue,
    Patch,
    PatchOp,
    PatchSchema,
    PermissionLevel,
)
from pydantic import BaseModel, field_validator
from pyspark.sql import SparkSession
class DatabricksOAuthToken(BaseModel):
    """An OAuth secret created for a service principal, together with its metadata."""

    id: str  # secret ID (``token.id`` in create_oauth_token)
    workspace_id: str
    oauth_secret: str
    created_at: datetime
    expires_at: datetime
    client_id: str  # the service principal's application ID in create_oauth_token

    @field_validator("oauth_secret", "client_id", mode="before")
    @classmethod
    def strip_whitespace(cls, v: object) -> object:
        """Trim surrounding whitespace from string input; pass any other value through unchanged."""
        if not isinstance(v, str):
            return v
        return v.strip()
class DatabricksCredentials(BaseModel):
    """Bundle of everything the setup script reports, serialisable as one JSON object."""

    oauth_token: DatabricksOAuthToken
    workspace_url: str
    service_principal_name: str = "espresso-ai-optimizer"
    service_principal_id: str
    warehouse_id: str
    warehouse_name: str

    def to_json(self) -> str:
        """Return the credentials as a JSON string.

        The model is dumped with ``model_dump(mode="json")`` and the resulting
        structure is encoded with ``json.dumps``.
        """
        json_ready = self.model_dump(mode="json")
        return json.dumps(json_ready)
# Reuse the active Spark session when there is one; otherwise build/get one.
spark = SparkSession.getActiveSession()
if spark is None:
    spark = SparkSession.builder.getOrCreate()

# NOTE(review): the conf value is assumed to be a bare host (no scheme), hence
# the "https://" prefix; an empty/None value yields just "https://".
workspace_url = "https://{}".format(spark.conf.get("spark.databricks.workspaceUrl") or "")
def find_workspace_admin_group(client):
    """Return the ID of the workspace group whose display name is ``admins``.

    Args:
        client: Databricks ``WorkspaceClient`` (anything with ``groups.list``).

    Returns:
        The ``id`` of the first group returned by the ``displayName eq 'admins'``
        filter.

    Raises:
        LookupError: if the filter returns no group. Previously a bare
            ``StopIteration`` escaped here, which carries no message.
    """
    # iter() tolerates the SDK returning either an iterator or a list.
    admins = next(iter(client.groups.list(filter="displayName eq 'admins'")), None)
    if admins is None:
        raise LookupError("No workspace group named 'admins' was found")
    return admins.id
def get_or_create_service_principal(client, display_name: str = "espresso-ai-optimizer"):
    """Return the service principal named ``display_name``, creating it if absent.

    Either way the principal is given the ``databricks-sql-access`` and
    ``allow-cluster-create`` entitlements. An existing principal is patched with
    a SCIM ``add`` operation; a new one is created with them.

    Args:
        client: Databricks ``WorkspaceClient``.
        display_name: Display name to look up or create. Defaults to
            ``"espresso-ai-optimizer"``.

    Returns:
        The first matching service principal, or the newly created one. For an
        existing principal this is the object from the list call, so its
        ``entitlements`` attribute may predate the patch applied here.

    Raises:
        ValueError: if ``display_name`` contains a single quote, which would
            break the quoted SCIM filter expression.
    """
    if "'" in display_name:
        raise ValueError(f"display_name must not contain a single quote: {display_name!r}")

    matches = client.service_principals.list(filter=f"displayName eq '{display_name}'")
    # Only the first match is used; stop there instead of materializing every result.
    existing = next(iter(matches), None)
    if existing is not None:
        client.service_principals.patch(
            id=existing.id,
            operations=[
                Patch(
                    op=PatchOp.ADD,
                    path="entitlements",
                    value=[
                        {"value": "databricks-sql-access"},
                        {"value": "allow-cluster-create"},
                    ],
                ),
            ],
            schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP],
        )
        return existing

    # No existing service principal was found, so create one.
    return client.service_principals.create(
        display_name=display_name,
        active=True,
        entitlements=[
            ComplexValue(value="allow-cluster-create"),
            ComplexValue(value="databricks-sql-access"),
        ],
    )
def create_oauth_token(client, service_principal):
    """Create an OAuth secret for ``service_principal`` and wrap it in a DatabricksOAuthToken.

    The secret is requested with a lifetime of 2 * 365 days. It is passed to the
    API as a ``"<seconds>s"`` string. The token's ``client_id`` is the service
    principal's application ID, and ``workspace_id`` is the client's workspace
    ID as a string.
    """
    two_years_in_seconds = 2 * 365 * 24 * 60 * 60
    secret = client.service_principal_secrets_proxy.create(
        service_principal_id=service_principal.id,
        lifetime=f"{two_years_in_seconds}s",
    )
    return DatabricksOAuthToken(
        id=secret.id,
        workspace_id=str(client.get_workspace_id()),
        oauth_secret=secret.secret,
        created_at=secret.create_time,
        expires_at=secret.expire_time,
        client_id=service_principal.application_id,
    )
def allow_service_principal_to_manage_warehouses(client, service_principal):
    """Grant ``service_principal`` CAN_MANAGE on every warehouse ``client`` can list.

    The principal is identified by its application ID. Permissions are applied
    one warehouse at a time via ``client.permissions.update``. An error on any
    warehouse propagates and stops the loop.
    """
    grantee = service_principal.application_id
    for wh in client.warehouses.list():
        acl = [
            AccessControlRequest(
                service_principal_name=grantee,
                permission_level=PermissionLevel.CAN_MANAGE,
            )
        ]
        client.permissions.update(
            request_object_type="warehouses",
            request_object_id=wh.id,
            access_control_list=acl,
        )
def make_service_principal_workspace_admin(client, service_principal):
    """Add ``service_principal`` as a member of the workspace ``admins`` group.

    Any exception whose message contains "already exists" or "duplicate"
    (case-insensitive) is ignored. The intent appears to be tolerating a re-run
    where the principal is already a member; this is an assumption about the
    API's error text, so confirm it. Every other exception is re-raised,
    including one raised while looking up the group.
    """
    try:
        admins_group_id = find_workspace_admin_group(client)
        client.groups.patch(
            id=admins_group_id,
            operations=[
                Patch(
                    op=PatchOp.ADD,
                    value={"members": [{"value": service_principal.id}]},
                )
            ],
            schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP],
        )
    except Exception as exc:
        message = str(exc).lower()
        if "already exists" in message or "duplicate" in message:
            return
        raise
def get_or_create_warehouse(client, name: str = "ESPRESSO_AI_WAREHOUSE"):
    """Return the ID of the SQL warehouse called ``name``, creating it if needed.

    Args:
        client: Databricks ``WorkspaceClient``.
        name: Exact warehouse name to look for / create. Defaults to
            ``"ESPRESSO_AI_WAREHOUSE"``.

    Returns:
        The ID of the first listed warehouse whose name equals ``name``, or the
        ID of a newly created one.

    A new warehouse is X-Small, single-cluster (min = max = 1), serverless, and
    auto-stops after 1 idle minute. ``warehouse_type`` is set to PRO because the
    SQL Warehouses API requires a PRO warehouse when
    ``enable_serverless_compute`` is true. Creation goes through the SDK's
    ``create_and_wait``, which waits before returning.
    """
    for warehouse in client.warehouses.list():
        if warehouse.name == name:
            return warehouse.id

    # Imported here so the lookup path above has no extra dependency.
    from databricks.sdk.service.sql import CreateWarehouseRequestWarehouseType

    created_warehouse = client.warehouses.create_and_wait(
        name=name,
        cluster_size="X-Small",
        auto_stop_mins=1,
        enable_serverless_compute=True,
        warehouse_type=CreateWarehouseRequestWarehouseType.PRO,
        min_num_clusters=1,
        max_num_clusters=1,
    )
    return created_warehouse.id
def is_account_admin_error(e):
    """Return True when error text ``e`` mentions "account admin" (case-insensitive)."""
    lowered = e.lower()
    return "account admin" in lowered
def is_metastore_admin_error(e):
    """Return True when error text ``e`` suggests a metastore admin is needed.

    Matches, case-insensitively, either a "does not have MANAGE on catalog"
    permission error or any mention of "metastore".
    """
    lowered = e.lower()
    # Both needles must be lowercase because they are searched in lowered text;
    # a mixed-case needle ("MANAGE") could never match.
    return "does not have manage on catalog" in lowered or "metastore" in lowered
def allow_service_principal_to_read_system_logs(client, service_principal):
    """Grant ``service_principal`` read access to the ``system`` catalog.

    Grants USE_CATALOG on ``system``, and USE_SCHEMA plus SELECT on every schema
    returned by ``client.schemas.list(catalog_name="system")``. The principal is
    identified by its application ID.

    Each grant is attempted independently. Failures are collected rather than
    raised. After all grants are tried:
    - a hint is printed if any failure looks like an account-admin error;
    - a hint is printed if any failure looks like a metastore-admin error;
    - every failure matching neither hint is printed verbatim, so it is not
      silently dropped.

    Errors raised by ``client.schemas.list`` itself are not caught and propagate.
    """
    errors = []

    def grant(asset_name, asset_type, privilege):
        # Best-effort: record the failure and continue with the remaining grants.
        try:
            client.grants.update(
                full_name=asset_name,
                securable_type=asset_type.value,
                changes=[
                    catalog.PermissionsChange(
                        add=[privilege], principal=service_principal.application_id
                    )
                ],
            )
        except Exception as e:
            errors.append(str(e))

    grant("system", catalog.SecurableType.CATALOG, catalog.Privilege.USE_CATALOG)
    for schema in client.schemas.list(catalog_name="system"):
        schema_full_name = f"system.{schema.name}"
        grant(schema_full_name, catalog.SecurableType.SCHEMA, catalog.Privilege.USE_SCHEMA)
        grant(schema_full_name, catalog.SecurableType.SCHEMA, catalog.Privilege.SELECT)

    if any(is_account_admin_error(e) for e in errors):
        print("\n⚠️ ACCOUNT ADMIN required: Ask an account admin to grant you access.")
    if any(is_metastore_admin_error(e) for e in errors):
        print("⚠️ METASTORE ADMIN required: Visit https://accounts.cloud.databricks.com/data")

    # Failures not covered by a hint above used to vanish without a trace.
    unexplained = [
        e for e in errors if not (is_account_admin_error(e) or is_metastore_admin_error(e))
    ]
    if unexplained:
        print(f"\n⚠️ {len(unexplained)} system-catalog grant(s) failed:")
        for message in unexplained:
            print(f"  - {message}")
if __name__ == "__main__":
    client = WorkspaceClient()

    # 1. Identity: the service principal plus an OAuth secret for it.
    service_principal = get_or_create_service_principal(client)
    oauth_token = create_oauth_token(client, service_principal)

    # 2. Compute: the ESPRESSO_AI_WAREHOUSE warehouse, then CAN_MANAGE on every warehouse.
    warehouse_id = get_or_create_warehouse(client)
    allow_service_principal_to_manage_warehouses(client, service_principal)

    # 3. Privileges: workspace admins membership and read access to the system catalog.
    make_service_principal_workspace_admin(client, service_principal)
    allow_service_principal_to_read_system_logs(client, service_principal)

    # 4. Collect everything into one payload and print it for hand-off.
    credentials = DatabricksCredentials(
        workspace_url=workspace_url,
        warehouse_id=warehouse_id,
        warehouse_name="ESPRESSO_AI_WAREHOUSE",
        service_principal_id=service_principal.id,
        service_principal_name=service_principal.display_name,
        oauth_token=oauth_token,
    )

    print("\n🎉 Setup complete! Here are the Databricks credentials to send to Espresso AI:")
    print("=" * 50, credentials.to_json(), "=" * 50, sep="\n")

# What the Script Does:
Troubleshooting
Questions?
Last updated