Update CodeGeneratorAtomicFlow.py
Browse files- CodeGeneratorAtomicFlow.py +19 -0
CodeGeneratorAtomicFlow.py
CHANGED
|
@@ -41,6 +41,10 @@ class CodeGeneratorAtomicFlow(ChatAtomicFlow):
|
|
| 41 |
|
| 42 |
@classmethod
|
| 43 |
def instantiate_from_config(cls, config):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
flow_config = deepcopy(config)
|
| 45 |
|
| 46 |
kwargs = {"flow_config": flow_config}
|
|
@@ -55,12 +59,20 @@ class CodeGeneratorAtomicFlow(ChatAtomicFlow):
|
|
| 55 |
return cls(**kwargs)
|
| 56 |
|
| 57 |
def _get_code_library_file(self, input_data: Dict[str, Any]):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
assert "memory_files" in input_data, "memory_files not passed to CodeGeneratorAtomicFlow"
|
| 59 |
assert "code_library" in input_data['memory_files'], "code_library not in memory_files"
|
| 60 |
code_library_file_location = input_data['memory_files']['code_library']
|
| 61 |
return code_library_file_location
|
| 62 |
|
| 63 |
def _get_code_library_content(self, input_data: Dict[str, Any]):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
assert "code_library" in input_data, "code_library not passed to CodeGeneratorAtomicFlow"
|
| 65 |
code_library = input_data['code_library']
|
| 66 |
if len(code_library) == 0:
|
|
@@ -68,6 +80,9 @@ class CodeGeneratorAtomicFlow(ChatAtomicFlow):
|
|
| 68 |
return code_library
|
| 69 |
|
| 70 |
def _update_prompts_and_input(self, input_data: Dict[str, Any]):
|
|
|
|
|
|
|
|
|
|
| 71 |
if 'goal' in input_data:
|
| 72 |
input_data['goal'] += self.hint_for_model
|
| 73 |
code_library_file_location = self._get_code_library_file(input_data)
|
|
@@ -78,6 +93,10 @@ class CodeGeneratorAtomicFlow(ChatAtomicFlow):
|
|
| 78 |
)
|
| 79 |
|
| 80 |
def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
self._update_prompts_and_input(input_data)
|
| 82 |
|
| 83 |
|
|
|
|
| 41 |
|
| 42 |
@classmethod
|
| 43 |
def instantiate_from_config(cls, config):
|
| 44 |
+
"""Instantiate a CodeGeneratorAtomicFlow from a configuration.
|
| 45 |
+
:param config: Configuration dictionary.
|
| 46 |
+
:return: Instantiated CodeGeneratorAtomicFlow.
|
| 47 |
+
"""
|
| 48 |
flow_config = deepcopy(config)
|
| 49 |
|
| 50 |
kwargs = {"flow_config": flow_config}
|
|
|
|
| 59 |
return cls(**kwargs)
|
| 60 |
|
| 61 |
def _get_code_library_file(self, input_data: Dict[str, Any]):
|
| 62 |
+
"""Get the code library file location from the input data.
|
| 63 |
+
:param input_data: Input data.
|
| 64 |
+
:return: Code library file location.
|
| 65 |
+
"""
|
| 66 |
assert "memory_files" in input_data, "memory_files not passed to CodeGeneratorAtomicFlow"
|
| 67 |
assert "code_library" in input_data['memory_files'], "code_library not in memory_files"
|
| 68 |
code_library_file_location = input_data['memory_files']['code_library']
|
| 69 |
return code_library_file_location
|
| 70 |
|
| 71 |
def _get_code_library_content(self, input_data: Dict[str, Any]):
|
| 72 |
+
"""Get the code library content from the input data.
|
| 73 |
+
:param input_data: Input data.
|
| 74 |
+
:return: Code library content.
|
| 75 |
+
"""
|
| 76 |
assert "code_library" in input_data, "code_library not passed to CodeGeneratorAtomicFlow"
|
| 77 |
code_library = input_data['code_library']
|
| 78 |
if len(code_library) == 0:
|
|
|
|
| 80 |
return code_library
|
| 81 |
|
| 82 |
def _update_prompts_and_input(self, input_data: Dict[str, Any]):
|
| 83 |
+
"""Update the prompts and input data.
|
| 84 |
+
:param input_data: Input data.
|
| 85 |
+
"""
|
| 86 |
if 'goal' in input_data:
|
| 87 |
input_data['goal'] += self.hint_for_model
|
| 88 |
code_library_file_location = self._get_code_library_file(input_data)
|
|
|
|
| 93 |
)
|
| 94 |
|
| 95 |
def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 96 |
+
"""Run the flow.
|
| 97 |
+
:param input_data: Input data.
|
| 98 |
+
:return: Output data.
|
| 99 |
+
"""
|
| 100 |
self._update_prompts_and_input(input_data)
|
| 101 |
|
| 102 |
|