"""
fluxion_ai.core.abstract_workflow
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Defines the AbstractWorkflow class, which serves as the base class for defining and executing workflows.
This module is part of the Fluxion framework and provides core functionality for constructing workflows
with nodes and managing their execution order.
"""
import os
import shutil
import tempfile
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from fluxion_ai.workflows.node import Node
class AbstractWorkflow(ABC):
    """
    Abstract base class for workflows.

    Defines the structure for adding nodes, validating dependencies,
    determining execution order, and executing the workflow. Subclasses
    implement :meth:`define_workflow` to add their nodes and dependencies.
    """

    # Input source accepted by both validation passes even though it is not a
    # node name. Presumably resolved by Node from the ``inputs`` given to
    # ``execute`` -- NOTE(review): confirm against Node.execute.
    _WORKFLOW_INPUT_SOURCE = "workflow_input"

    def __init__(self, name: str, workflow_dir: Optional[str] = None):
        """
        Initialize the workflow.

        Args:
            name (str): The name of the workflow.
            workflow_dir (Optional[str]): Directory used by the workflow. When
                omitted, a new directory named ``fluxion_workflow_<uuid4>`` is
                created under ``tempfile.gettempdir()``; this instance owns it
                and removes it (with its contents) in ``__del__``. A
                caller-supplied directory is created if missing but is never
                deleted by the workflow.
        """
        self.name = name
        self._nodes: Dict[str, Node] = {}
        self.initial_inputs = {}
        # Only a directory this instance created itself is removed in __del__.
        self._owns_workflow_dir = not workflow_dir
        if workflow_dir:
            self.workflow_dir = workflow_dir
        else:
            # Create a temporary directory for the workflow
            self.workflow_dir = os.path.join(
                tempfile.gettempdir(), f"fluxion_workflow_{uuid.uuid4()}"
            )
        os.makedirs(self.workflow_dir, exist_ok=True)

    @property
    def nodes(self):
        """Dict mapping node name to Node, in the order nodes were added."""
        return self._nodes

    @nodes.setter
    def nodes(self, nodes):
        """
        Reject direct assignment of the node mapping.

        Raises:
            ValueError: Always; register nodes through ``add_node`` instead.
        """
        raise ValueError("Cannot set nodes directly. Use add_node method instead.")

    @abstractmethod
    def define_workflow(self):
        """
        Define the workflow by adding nodes and dependencies.

        Must be implemented in subclasses.
        """
        pass

    def get_node_by_name(self, name: str):
        """
        Get a node by name.

        Args:
            name (str): The name of the node to retrieve.

        Returns:
            Node: The node with the given name.

        Raises:
            ValueError: If no node with that name has been added.
        """
        try:
            return self.nodes[name]
        except KeyError:
            raise ValueError(f"Node '{name}' does not exist in the workflow.") from None

    def add_node(self, node):
        """
        Add a node to the workflow.

        Dependencies are checked against the nodes already registered, so a
        node's dependencies must be added before the node itself.

        Args:
            node (Node): The node to add.

        Raises:
            ValueError: If a node with the same name already exists, or one of
                the node's dependencies has not been added yet.
        """
        if node.name in self.nodes:
            raise ValueError(f"Node '{node.name}' already exists in the workflow.")
        for dependency in node.get_dependencies(self.nodes).values():
            if dependency.name not in self.nodes:
                raise ValueError(f"Dependency '{dependency.name}' for node '{node.name}' does not exist.")
        self.nodes[node.name] = node

    def get_node_dependencies(self, node_name: str) -> Dict[str, Node]:
        """
        Get the dependencies of a node by name.

        Args:
            node_name (str): The name of the node.

        Returns:
            Dict[str, Node]: Dictionary of dependencies for the node, as
            returned by ``Node.get_dependencies``.

        Raises:
            ValueError: If the node does not exist.
        """
        node = self.get_node_by_name(node_name)
        return node.get_dependencies(self.nodes)

    def get_node_parents(self, node_name: str) -> List[Node]:
        """
        Get the parent nodes of a node by name.

        Args:
            node_name (str): The name of the node.

        Returns:
            List[Node]: The parent Node objects (the rest of this class reads
            ``.name`` from each element), as returned by ``Node.get_parents``.

        Raises:
            ValueError: If the node does not exist.
        """
        node = self.get_node_by_name(node_name)
        return node.get_parents(self.nodes)

    def _validate_dependencies(self) -> None:
        """
        Validate the dependencies for all nodes in the workflow.

        Runs a depth-first search over parent links to detect cycles, checks
        that every parent is a registered Node, and checks that every input
        source is either a registered node or the workflow-input sentinel.

        Raises:
            ValueError: If the workflow is empty, a node has an invalid
                dependency or input source, or a circular dependency is detected.
        """
        if not self.nodes:
            raise ValueError("Workflow has no nodes to validate.")

        # Computed once rather than per visited node; the list preserves the
        # original ordering for the error message, the set gives O(1) lookups.
        node_names = [n.name for n in self.nodes.values()]
        known_names = set(node_names)
        visited = set()
        stack = set()  # nodes on the current DFS path; re-entering one means a cycle

        def visit(node_name):
            if node_name in stack:
                raise ValueError(f"Circular dependency detected involving '{node_name}'.")
            if node_name in visited:
                return
            stack.add(node_name)
            visited.add(node_name)
            node = self.nodes.get(node_name)
            if node is None:
                raise ValueError(f"Node '{node_name}' does not exist in the workflow.")
            for dependency in node.get_parents(self.nodes):
                # Type-check first: a non-Node may not even have a ``name``.
                if not isinstance(dependency, Node):
                    raise ValueError(f"Dependency '{dependency}' is not a valid Node.")
                if dependency.name not in known_names:
                    raise ValueError(
                        f"Dependency '{dependency.name}' for node '{node_name}' does not exist. "
                        f"Available nodes: {node_names}"
                    )
                visit(dependency.name)
            stack.remove(node_name)

        for node_name in self.nodes:
            visit(node_name)

        # Must accept the same sentinel as _validate_inputs_and_outputs, which
        # execute() runs afterwards; otherwise that allowance is unreachable.
        for node in self.nodes.values():
            for input_key, source in node.inputs.items():
                if source not in self.nodes and source != self._WORKFLOW_INPUT_SOURCE:
                    raise ValueError(f"Input '{input_key}' references non-existent node '{source}'.")

    def _validate_inputs_and_outputs(self):
        """
        Validate that all inputs and outputs in the workflow are consistent.

        Raises:
            ValueError: If an input source is neither a registered node nor the
                workflow-input sentinel.
        """
        for node in self.nodes.values():
            # Check if inputs are resolved
            for input_key, source in node.inputs.items():
                if source not in self.nodes and source != self._WORKFLOW_INPUT_SOURCE:
                    raise ValueError(f"Input '{input_key}' references non-existent node '{source}'.")

    def determine_execution_order(self) -> List[str]:
        """
        Determine the execution order of nodes based on dependencies.

        Uses a post-order depth-first search so every node appears after all
        of its parents.

        Returns:
            List[str]: A list of node names in the order they should be executed.

        Raises:
            ValueError: If the workflow is empty or contains a cycle.
        """
        if not self.nodes:
            raise ValueError("Workflow has no nodes to determine execution order.")

        order = []
        visited = set()
        in_progress = set()  # current DFS path, so cycles are reported, not ignored

        def dfs(node_name):
            if node_name in visited:
                return
            if node_name in in_progress:
                raise ValueError(f"Circular dependency detected involving '{node_name}'.")
            in_progress.add(node_name)
            for dependency in self.nodes[node_name].get_parents(self.nodes):
                dfs(dependency.name)
            in_progress.discard(node_name)
            visited.add(node_name)
            order.append(node_name)

        for node_name in self.nodes:
            dfs(node_name)
        return order

    def execute(self, inputs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Execute the workflow.

        Validates the graph, then runs each node in dependency order. Each node
        receives the results gathered so far plus the unchanged ``inputs``.

        Args:
            inputs (Dict[str, Any], optional): Inputs for the workflow.

        Returns:
            Dict[str, Any]: Mapping of node name to that node's execution result.

        Raises:
            ValueError: If validation fails.
        """
        self._validate_dependencies()
        self._validate_inputs_and_outputs()
        execution_order = self.determine_execution_order()
        results = {}
        for node_name in execution_order:
            node = self.nodes[node_name]
            results[node_name] = node.execute(results=results, inputs=inputs)
        return results

    def visualize(self, output_path: str = "workflow_graph", format: str = "png"):
        """
        Visualize the workflow as a directed graph using graphviz.

        Args:
            output_path (str): The output path for the generated graph (without extension).
            format (str): The format of the output file (e.g., 'png', 'pdf').

        Returns:
            str: Path to the generated visualization file.

        Raises:
            ImportError: If the optional ``graphviz`` package is not installed.
        """
        try:
            from graphviz import Digraph
        except ImportError as err:
            raise ImportError(
                "The 'graphviz' package is required for workflow visualization. "
                "Please install it using 'pip install graphviz'."
            ) from err

        dot = Digraph(name=self.name, format=format)
        dot.attr(rankdir='LR')
        # Add nodes to the graph
        for node in self.nodes.values():
            dot.node(node.name, label=node.name)
        # Add edges (parent -> child) to represent dependencies
        for node in self.nodes.values():
            for dependency in node.get_parents(self.nodes):
                dot.edge(dependency.name, node.name)
        # Render the graph; cleanup=True removes the intermediate DOT source file
        output_file = dot.render(filename=output_path, cleanup=True)
        print(f"Workflow visualization saved to: {output_file}")
        return output_file

    def __del__(self):
        """
        Clean up the temporary workflow directory.

        Only a directory this instance created in ``__init__`` is removed, and
        it is removed together with any files nodes wrote into it. Errors are
        reported rather than raised, since exceptions escaping ``__del__`` are
        ignored by the interpreter anyway.
        """
        # __init__ may not have completed (e.g. a subclass raised before calling it).
        workflow_dir = getattr(self, "workflow_dir", None)
        if not workflow_dir or not getattr(self, "_owns_workflow_dir", False):
            return
        if not os.path.isdir(workflow_dir):
            return
        print("Deleting the workflow directory...")
        try:
            # rmtree rather than rmdir: rmdir fails as soon as the directory is non-empty.
            shutil.rmtree(workflow_dir)
        except OSError as err:
            print(f"Could not delete temporary workflow directory {workflow_dir}: {err}")
        else:
            print(f"Deleted temporary workflow directory: {workflow_dir}")