from __future__ import annotations
import asyncio
import logging
import pickle
import time
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional
from threadpoolctl import threadpool_limits
from dataclay import utils
from dataclay.config import LEGACY_DEPS, set_runtime, settings
from dataclay.event_loop import dc_to_thread_io
from dataclay.exceptions import (
DataClayException,
DoesNotExistError,
NoOtherBackendsAvailable,
ObjectWithWrongBackendIdError,
)
from dataclay.lock_manager import lock_manager
from ..metadata.kvdata import ObjectMetadata
from dataclay.runtime import BackendRuntime
from dataclay.utils.serialization import dcdumps, dcloads, recursive_dcloads
from dataclay.utils.telemetry import trace
from dataclay.event_loop import get_dc_event_loop
if TYPE_CHECKING:
from uuid import UUID
from dataclay.dataclay_object import DataClayObject
# Tracer used by the `start_as_current_span` decorators on the API entry points below.
tracer = trace.get_tracer(__name__)
# Module-level logger; the standard logger is wrapped in dataclay's `utils.LoggerEvent`.
logger: logging.Logger = utils.LoggerEvent(logging.getLogger(__name__))
[docs]
# Backend-side API. Owns a BackendRuntime and exposes coroutine entry points for
# object registration/persistence, active-method calls, remote attribute access,
# object versioning, moves/replication and shutdown. Every entry point resolves
# its target with `self.runtime.get_object_by_id` and delegates to the runtime.
class BackendAPI:
def __init__(self, name: str, port: int, backend_id: UUID, kv_host: str, kv_port: int):
    """Record the backend identity and install its runtime as the active one.

    Args:
        name: Name of the backend.
        port: Port of the backend. For now it is used exclusively to uniquely
            identify an EE, because the name is shared by every EE that shares
            an SL (which happens in HPC deployments).
        backend_id: Identifier of this backend.
        kv_host: Host forwarded to ``BackendRuntime``.
        kv_port: Port forwarded to ``BackendRuntime``.
    """
    self.name, self.port = name, port
    self.backend_id = backend_id

    # Build the runtime for this backend and publish it through set_runtime.
    runtime = BackendRuntime(kv_host, kv_port, backend_id)
    self.runtime = runtime
    set_runtime(runtime)
async def _is_ready(self, timeout: Optional[float], pause: float) -> bool:
    """Poll until the metadata service is ready and the dataClay ID is known.

    First waits for ``self.runtime.metadata_service.is_ready(timeout)``. If that
    succeeds, repeatedly asks the metadata service for the ``"this"`` dataClay
    entry; once found, its ``id`` is stored in ``settings.dataclay_id``.

    Args:
        timeout: Maximum number of seconds to keep polling for the dataClay
            entry, or ``None`` to poll indefinitely.
        pause: Seconds to wait between two polling attempts.

    Returns:
        ``True`` when the dataClay ID was obtained, ``False`` if the metadata
        service was not ready or the timeout expired first.
    """
    metadata_service = self.runtime.metadata_service
    # Monotonic clock: elapsed time must not be affected by wall-clock changes.
    start = time.monotonic()
    elapsed = 0.0

    if not await metadata_service.is_ready(timeout):
        return False

    # Check that dataclay_id is defined. If it is not defined, it could break things.
    while timeout is None or elapsed < timeout:
        try:
            dataclay_obj = await metadata_service.get_dataclay("this")
        except DoesNotExistError:
            # Yield to the event loop while waiting; a blocking time.sleep here
            # would stall every other coroutine scheduled on this loop.
            await asyncio.sleep(pause)
            elapsed = time.monotonic() - start
        else:
            settings.dataclay_id = dataclay_obj.id
            return True
    return False
[docs]
async def is_ready(self, timeout: Optional[float] = None, pause: float = 0.5):
    """Run the readiness probe on the dataClay event loop and await its result.

    The probe coroutine (``_is_ready``) is submitted to the loop returned by
    ``get_dc_event_loop()`` and the resulting concurrent future is awaited from
    the calling loop.

    Args:
        timeout: Forwarded to ``_is_ready``; ``None`` by default.
        pause: Forwarded to ``_is_ready``; 0.5 by default.

    Returns:
        Whatever ``_is_ready`` returns.
    """
    probe = self._is_ready(timeout, pause)
    dc_loop = get_dc_event_loop()
    pending = asyncio.run_coroutine_threadsafe(probe, dc_loop)
    return await asyncio.wrap_future(pending)
# Object Methods
[docs]
# Register serialized objects sent to this backend, either as new replicas
# (make_replica=True) or as moved objects whose master becomes this backend
# (make_replica=False).
#
# Each payload is decoded with `dcloads` into (metadata_dict, dc_properties,
# getstate); the metadata is turned into an ObjectMetadata, the instance is
# resolved through the runtime, its state is overwritten, and the resulting
# metadata is upserted in the metadata service.
#
# NOTE(review): `len(serialized_objects)` is evaluated for the debug message, so
# callers must pass a sized collection even though the annotation is Iterable.
# NOTE(review): indentation was lost in this listing, so the exact nesting (in
# particular which statements run under the writer lock) cannot be confirmed here.
async def register_objects(self, serialized_objects: Iterable[bytes], make_replica: bool):
logger.debug("Receiving (%d) objects to register", len(serialized_objects))
for object_bytes in serialized_objects:
metadata_dict, dc_properties, getstate = await dcloads(object_bytes)
# Two validation entry points depending on LEGACY_DEPS
# (presumably pydantic v1 `parse_obj` vs. v2 `model_validate` -- confirm).
if LEGACY_DEPS:
dc_meta = ObjectMetadata.parse_obj(metadata_dict)
else:
dc_meta = ObjectMetadata.model_validate(metadata_dict)
instance = await self.runtime.get_object_by_id(dc_meta.id)
# An instance that is already local is expected to be a replica; a repeated
# replica request is only logged and the object is skipped.
if instance._dc_is_local:
assert instance._dc_is_replica
if make_replica:
logger.warning("Replica already exists with id=%s", instance._dc_meta.id)
continue
state = {"_dc_meta": dc_meta}
async with lock_manager.get_lock(instance._dc_meta.id).writer_lock:
# Update object state and flags
state["_dc_is_loaded"] = True
state["_dc_is_local"] = True
vars(instance).update(state)
# Restore the payload: through __setstate__ when a state was sent,
# otherwise by writing the properties straight into the instance __dict__.
if getstate:
instance.__setstate__(getstate)
else:
vars(instance).update(dc_properties)
self.runtime.data_manager.add_hard_reference(instance)
if make_replica:
instance._dc_is_replica = True
instance._dc_meta.replica_backend_ids.add(self.backend_id)
else:
# If not make_replica then it is a move
instance._dc_meta.master_backend_id = self.backend_id
instance._dc_meta.replica_backend_ids.discard(self.backend_id)
# we can only move masters
# instance._dc_is_replica = False # already set by vars(instance).update(state)
# NOTE(review): the `state` dict built above has no `_dc_is_replica` key, so
# the claim in the previous comment does not hold for this code path -- confirm.
# Open question: should the metadata always be updated here, or by the calling backend?
await self.runtime.metadata_service.upsert_object(instance._dc_meta)
[docs]
@tracer.start_as_current_span("make_persistent")
async def make_persistent(self, serialized_objects: Iterable[bytes]):
    """Load the received objects and register them as owned by this backend.

    Every payload is decoded with ``recursive_dcloads`` into a shared mapping of
    already-loaded objects; each decoded object is flagged as local and loaded,
    and this backend is set as its master. Afterwards every object in the
    mapping is added to the in-memory objects of the runtime, hard-referenced by
    the data manager, upserted in the metadata service and flagged as registered.

    Args:
        serialized_objects: Serialized objects to persist. Must support
            ``len()`` (it is used for logging and for the sanity check).
    """
    logger.debug("Receiving (%d) objects to make persistent", len(serialized_objects))

    loaded: dict[UUID, DataClayObject] = {}
    for payload in serialized_objects:
        obj = await recursive_dcloads(payload, loaded)
        obj._dc_is_local = True
        obj._dc_is_loaded = True
        obj._dc_meta.master_backend_id = self.backend_id

    # One loaded object is expected per received payload.
    assert len(serialized_objects) == len(loaded)

    for obj in loaded.values():
        object_id = obj._dc_meta.id
        logger.debug(
            "(%s) Registering %s",
            object_id,
            obj.__class__.__name__,
        )
        self.runtime.inmemory_objects[object_id] = obj
        self.runtime.data_manager.add_hard_reference(obj)
        await self.runtime.metadata_service.upsert_object(obj._dc_meta)
        obj._dc_is_registered = True
[docs]
@tracer.start_as_current_span("call_active_method")
async def call_active_method(
    self,
    object_id: UUID,
    method_name: str,
    args: tuple[Any],
    kwargs: dict[str, Any],
    exec_constraints: dict[str, Any],
) -> tuple[Optional[bytes], bool]:
    """Entry point for calling an active method of a DataClayObject.

    Args:
        object_id: ID of the object whose method is called.
        method_name: Name of the method, looked up with ``getattr``.
        args: Serialized positional arguments (decoded with ``dcloads``).
        kwargs: Serialized keyword arguments (decoded with ``dcloads``).
        exec_constraints: Execution constraints. Only ``"max_threads"`` is read
            here; a missing value or ``0`` means no thread limit.

    Returns:
        A ``(payload, is_exception)`` pair:

        * ``(pickled ObjectWithWrongBackendIdError, False)`` when the object is
          not local to this backend;
        * ``(pickled exception, True)`` when the method raised;
        * ``(serialized result or None, False)`` otherwise.
    """
    logger.debug("(%s) Receiving remote call to activemethod '%s'", object_id, method_name)

    instance = await self.runtime.get_object_by_id(object_id)

    # If the object isn't local (not owned by this backend), a custom exception is sent to the
    # client to update the object's backend_id, and call_active_method again to the correct backend
    if not instance._dc_is_local:
        # NOTE: We sync the metadata because, when consolidating an object, the proxy
        # might point to a wrong backend_id that is this very backend, creating an
        # infinite loop. Passing the backend_id of the new object to the proxy would
        # also solve it, but that is prone to race conditions (e.g. a move before the
        # consolidation). Therefore we rely on the metadata, which is more reliable.
        logger.warning("(%s) Wrong backend", object_id)
        await self.runtime.sync_object_metadata(instance)
        logger.warning(
            "(%s) Update backend to %s", object_id, instance._dc_meta.master_backend_id
        )
        wrong_backend = ObjectWithWrongBackendIdError(
            instance._dc_meta.master_backend_id, instance._dc_meta.replica_backend_ids
        )
        # The flag is deliberately left False here (presumably the caller inspects
        # the payload type to retry on the right backend -- confirm against the client).
        return pickle.dumps(wrong_backend), False

    # Deserialize arguments
    args, kwargs = await asyncio.gather(dcloads(args), dcloads(kwargs))

    # Call activemethod in another thread
    logger.info("(%s) *** Starting activemethod '%s' in executor", object_id, method_name)
    max_threads = exec_constraints.get("max_threads", 0)
    if max_threads == 0:
        max_threads = None
    logger.info("(%s) Max threads for activemethod: %s", object_id, max_threads)

    # TODO: Check that the threadpool_limit is not limiting our internal pool of threads.
    # like when we are serializing dataclay objects.
    with threadpool_limits(limits=max_threads):
        try:
            func = getattr(instance, method_name)
            if asyncio.iscoroutinefunction(func):
                logger.debug("(%s) Awaiting activemethod coroutine", object_id)
                result = await func(*args, **kwargs)
            else:
                logger.debug("(%s) Running activemethod in new thread", object_id)
                result = await dc_to_thread_io(func, *args, **kwargs)
        except Exception as e:
            # If an exception was raised, serialize it and return it to be raised by the client
            logger.info("(%s) *** Exception in activemethod '%s'", object_id, method_name)
            try:
                return pickle.dumps(e), True
            except Exception:
                # Pickling may fail with TypeError, PicklingError, AttributeError, ...
                # depending on what the exception object holds; try simpler forms.
                pass
            try:
                # Best effort: same exception type, carrying only the message.
                return pickle.dumps(type(e)(str(e))), True
            except Exception:
                # The type cannot be rebuilt from a message or is itself
                # unpicklable: fall back to a builtin that always pickles.
                return pickle.dumps(RuntimeError(f"{type(e).__name__}: {e}")), True

    logger.info("(%s) *** Finished activemethod '%s' in executor", object_id, method_name)

    # Serialize the result if not None
    if result is not None:
        result = await dcdumps(result)
    return result, False
# Store Methods
[docs]
@tracer.start_as_current_span("get_object_attribute")
async def get_object_attribute(self, object_id: UUID, attribute: str) -> tuple[bytes, bool]:
    """Return the value of an attribute of the object with the given ID.

    Args:
        object_id: ID of the object.
        attribute: Name of the attribute.

    Returns:
        A ``(payload, is_exception)`` pair:

        * ``(pickled ObjectWithWrongBackendIdError, False)`` when the object is
          not local to this backend;
        * ``(serialized attribute value, False)`` on success;
        * ``(pickled exception, True)`` when reading or serializing failed.
    """
    logger.debug("(%s) Receiving remote call to __getattribute__ '%s'", object_id, attribute)
    instance = await self.runtime.get_object_by_id(object_id)

    if not instance._dc_is_local:
        logger.warning("(%s) Wrong backend", object_id)
        # Refresh the metadata so the error carries the current owner.
        await self.runtime.sync_object_metadata(instance)
        logger.warning(
            "(%s) Update backend to %s", object_id, instance._dc_meta.master_backend_id
        )
        wrong_backend = ObjectWithWrongBackendIdError(
            instance._dc_meta.master_backend_id, instance._dc_meta.replica_backend_ids
        )
        # The flag is deliberately left False here (presumably the caller inspects
        # the payload type to retry on the right backend -- confirm against the client).
        return pickle.dumps(wrong_backend), False

    try:
        value = await dc_to_thread_io(getattr, instance, attribute)
        return await dcdumps(value), False
    except Exception as e:
        try:
            return pickle.dumps(e), True
        except Exception:
            # Pickling may fail with TypeError, PicklingError, AttributeError, ...
            pass
        try:
            # Best effort: same exception type, carrying only the message.
            return pickle.dumps(type(e)(str(e))), True
        except Exception:
            # Last resort: a builtin exception that is always picklable.
            return pickle.dumps(RuntimeError(f"{type(e).__name__}: {e}")), True
[docs]
@tracer.start_as_current_span("set_object_attribute")
async def set_object_attribute(
    self, object_id: UUID, attribute: str, serialized_attribute: bytes
) -> tuple[Optional[bytes], bool]:
    """Set an attribute of the object with the given ID.

    Args:
        object_id: ID of the object.
        attribute: Name of the attribute.
        serialized_attribute: Serialized new value (decoded with ``dcloads``).

    Returns:
        A ``(payload, is_exception)`` pair:

        * ``(pickled ObjectWithWrongBackendIdError, False)`` when the object is
          not local to this backend;
        * ``(None, False)`` on success;
        * ``(pickled exception, True)`` when decoding or assigning failed.
    """
    logger.debug("(%s) Receiving remote call to __setattr__ '%s'", object_id, attribute)
    instance = await self.runtime.get_object_by_id(object_id)

    if not instance._dc_is_local:
        logger.warning("(%s) Wrong backend", object_id)
        # Refresh the metadata so the error carries the current owner.
        await self.runtime.sync_object_metadata(instance)
        logger.warning(
            "(%s) Update backend to %s", object_id, instance._dc_meta.master_backend_id
        )
        wrong_backend = ObjectWithWrongBackendIdError(
            instance._dc_meta.master_backend_id, instance._dc_meta.replica_backend_ids
        )
        # The flag is deliberately left False here (presumably the caller inspects
        # the payload type to retry on the right backend -- confirm against the client).
        return pickle.dumps(wrong_backend), False

    try:
        object_attribute = await dcloads(serialized_attribute)
        await dc_to_thread_io(setattr, instance, attribute, object_attribute)
        return None, False
    except Exception as e:
        try:
            return pickle.dumps(e), True
        except Exception:
            # Pickling may fail with TypeError, PicklingError, AttributeError, ...
            pass
        try:
            # Best effort: same exception type, carrying only the message.
            return pickle.dumps(type(e)(str(e))), True
        except Exception:
            # Last resort: a builtin exception that is always picklable.
            return pickle.dumps(RuntimeError(f"{type(e).__name__}: {e}")), True
[docs]
@tracer.start_as_current_span("del_object_attribute")
async def del_object_attribute(
    self, object_id: UUID, attribute: str
) -> tuple[Optional[bytes], bool]:
    """Delete an attribute of the object with the given ID.

    Args:
        object_id: ID of the object.
        attribute: Name of the attribute.

    Returns:
        A ``(payload, is_exception)`` pair:

        * ``(pickled ObjectWithWrongBackendIdError, False)`` when the object is
          not local to this backend;
        * ``(None, False)`` on success;
        * ``(pickled exception, True)`` when the deletion failed.
    """
    logger.debug("(%s) Receiving remote call to __delattr__'%s'", object_id, attribute)
    instance = await self.runtime.get_object_by_id(object_id)

    if not instance._dc_is_local:
        logger.warning("(%s) Wrong backend", object_id)
        # Refresh the metadata so the error carries the current owner.
        await self.runtime.sync_object_metadata(instance)
        logger.warning(
            "(%s) Update backend to %s", object_id, instance._dc_meta.master_backend_id
        )
        wrong_backend = ObjectWithWrongBackendIdError(
            instance._dc_meta.master_backend_id, instance._dc_meta.replica_backend_ids
        )
        # The flag is deliberately left False here (presumably the caller inspects
        # the payload type to retry on the right backend -- confirm against the client).
        return pickle.dumps(wrong_backend), False

    try:
        await dc_to_thread_io(delattr, instance, attribute)
        return None, False
    except Exception as e:
        try:
            return pickle.dumps(e), True
        except Exception:
            # Pickling may fail with TypeError, PicklingError, AttributeError, ...
            pass
        try:
            # Best effort: same exception type, carrying only the message.
            return pickle.dumps(type(e)(str(e))), True
        except Exception:
            # Last resort: a builtin exception that is always picklable.
            return pickle.dumps(RuntimeError(f"{type(e).__name__}: {e}")), True
[docs]
@tracer.start_as_current_span("get_object_properties")
async def get_object_properties(self, object_id: UUID) -> bytes:
    """Fetch and serialize the properties of the object with the given ID.

    Args:
        object_id: ID of the object.

    Returns:
        The object properties, as obtained from the runtime and serialized
        with ``dcdumps``.
    """
    runtime = self.runtime
    target = await runtime.get_object_by_id(object_id)
    properties = await runtime.get_object_properties(target)
    serialized = await dcdumps(properties)
    return serialized
[docs]
@tracer.start_as_current_span("update_object_properties")
async def update_object_properties(self, object_id: UUID, serialized_properties: bytes):
    """Overwrite the properties of the object with the given ID.

    The object lookup and the decoding of the payload are awaited together;
    the decoded properties are then handed to the runtime.

    Args:
        object_id: ID of the object to update.
        serialized_properties: Serialized properties (decoded with ``dcloads``).
    """
    lookup = self.runtime.get_object_by_id(object_id)
    decoding = dcloads(serialized_properties)
    target, properties = await asyncio.gather(lookup, decoding)
    await self.runtime.update_object_properties(target, properties)
[docs]
async def new_object_version(self, object_id: UUID):
    """Create a new version of the object with the given ID.

    This entry point for new_version is solely for COMPSs (called from Java).

    Args:
        object_id: ID of the object to create a new version from.

    Returns:
        The result of ``getID()`` on the new version created by the runtime
        (presumably an identifier of the new object -- confirm against
        ``DataClayObject.getID``).
    """
    source = await self.runtime.get_object_by_id(object_id)
    versioned = await self.runtime.new_object_version(source)
    return versioned.getID()
[docs]
async def consolidate_object_version(self, object_id: UUID):
    """Consolidate the version object with the given ID.

    Args:
        object_id: ID of the object handed to ``runtime.consolidate_version``.
    """
    runtime = self.runtime
    version = await runtime.get_object_by_id(object_id)
    await runtime.consolidate_version(version)
[docs]
@tracer.start_as_current_span("proxify_object")
async def proxify_object(self, object_id: UUID, new_object_id: UUID):
    """Turn the object with the given ID into a proxy of another object ID.

    Args:
        object_id: ID of the object to proxify.
        new_object_id: ID passed to ``runtime.proxify_object`` as the target.
    """
    logger.debug("Proxifying object %s to %s", object_id, new_object_id)
    runtime = self.runtime
    target = await runtime.get_object_by_id(object_id)
    await runtime.proxify_object(target, new_object_id)
[docs]
@tracer.start_as_current_span("change_object_id")
async def change_object_id(self, object_id: UUID, new_object_id: UUID):
    """Give a new ID to the object with the given ID.

    Args:
        object_id: Current ID of the object.
        new_object_id: ID passed to ``runtime.change_object_id``.
    """
    runtime = self.runtime
    target = await runtime.get_object_by_id(object_id)
    await runtime.change_object_id(target, new_object_id)
[docs]
@tracer.start_as_current_span("send_objects")
async def send_objects(
    self,
    object_ids: Iterable[UUID],
    backend_id: UUID,
    make_replica: bool,
    recursive: bool,
    remotes: bool,
):
    """Send the objects with the given IDs to another backend.

    Args:
        object_ids: IDs of the objects to send.
        backend_id: ID of the destination backend.
        make_replica: Logged as "replicate" when true and "move" otherwise;
            forwarded to ``runtime.send_objects``.
        recursive: Forwarded to ``runtime.send_objects``.
        remotes: Forwarded to ``runtime.send_objects``.
    """
    if make_replica:
        action = "replicate"
    else:
        action = "move"
    logger.debug("Receiving objects to %s", action)

    # Resolve every object concurrently before handing them to the runtime.
    lookups = (self.runtime.get_object_by_id(object_id) for object_id in object_ids)
    instances = await asyncio.gather(*lookups)
    await self.runtime.send_objects(instances, backend_id, make_replica, recursive, remotes)
# Shutdown
[docs]
@tracer.start_as_current_span("stop")
async def stop(self):
    """Stop the backend by awaiting the runtime's ``stop`` coroutine."""
    runtime = self.runtime
    await runtime.stop()
[docs]
@tracer.start_as_current_span("flush_all")
async def flush_all(self):
    """Await ``flush_all`` on the runtime's data manager."""
    data_manager = self.runtime.data_manager
    await data_manager.flush_all()
[docs]
@tracer.start_as_current_span("move_all_objects")
async def move_all_objects(self):
    """Hand the objects mastered by this backend over to the other backends.

    All object metadata is fetched from the metadata service and grouped by
    master backend. A per-backend target is computed as the ceiling of
    ``total objects / (number of backends - 1)``; every other backend that is
    below the target receives objects from this backend until it reaches the
    target or this backend has no objects left.

    Raises:
        NoOtherBackendsAvailable: If this is the only known backend.
    """
    dc_objects = await self.runtime.metadata_service.get_all_objects()
    await self.runtime.backend_clients.update()
    backends = self.runtime.backend_clients

    if len(backends) <= 1:
        raise NoOtherBackendsAvailable()

    # Ceiling division: objects are spread over the backends that remain.
    target = -(len(dc_objects) // -(len(backends) - 1))

    backends_objects = {backend_id: [] for backend_id in backends.keys()}
    for object_md in dc_objects.values():
        backends_objects[object_md.master_backend_id].append(object_md.id)

    # Negative value: how many objects the backend can still take.
    backends_diff = {
        backend_id: len(owned_ids) - target
        for backend_id, owned_ids in backends_objects.items()
    }

    object_ids = backends_objects[self.backend_id]
    for new_backend_id in backends_objects:
        if not object_ids:
            # Nothing left to hand over.
            break
        if new_backend_id == self.backend_id or backends_diff[new_backend_id] >= 0:
            continue

        # Stop as soon as this backend runs out of objects: popping from the
        # empty list would raise IndexError while a deficit remains elsewhere.
        while backends_diff[new_backend_id] < 0 and object_ids:
            object_id = object_ids.pop()
            # NOTE(review): `move_object` is not defined in the visible part of
            # this class -- confirm it exists and what it returns.
            outcome = self.move_object(object_id, new_backend_id, None)
            if asyncio.iscoroutine(outcome):
                # An un-awaited coroutine would never run; make the move happen.
                await outcome
            backends_diff[new_backend_id] += 1
# Replicas
[docs]
async def new_object_replica(
    self,
    object_id: UUID,
    backend_id: UUID = None,
    recursive: bool = False,
    remotes: bool = True,
):
    """Create a replica of the object with the given ID.

    Args:
        object_id: ID of the object to replicate.
        backend_id: Destination backend, forwarded to the runtime; ``None`` by
            default (presumably lets the runtime choose -- confirm).
        recursive: Forwarded to ``runtime.new_object_replica``; false by default.
        remotes: Forwarded to ``runtime.new_object_replica``; true by default.
    """
    runtime = self.runtime
    source = await runtime.get_object_by_id(object_id)
    await runtime.new_object_replica(source, backend_id, recursive, remotes)