Spaces:
Sleeping
Sleeping
| """Client script. | |
| This script does the following: | |
| - Query crypto-parameters and pre/post-processing parameters (client.zip) | |
| - Quantize the inputs using the parameters | |
| - Encrypt data using the crypto-parameters | |
| - Send the encrypted data to the server (async using grequests) | |
| - Collect the data and decrypt it | |
| - De-quantize the decrypted results | |
| """ | |
| import io | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import grequests | |
| import numpy | |
| import requests | |
| import torch | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| from concrete.ml.deployment import FHEModelClient | |
# Location of the deployment server; each setting may be overridden from the environment.
PORT = os.getenv("PORT", "5000")
IP = os.getenv("IP", "localhost")
URL = os.getenv("URL", f"http://{IP}:{PORT}")
# Sample-count setting, parsed as an integer (defaults to 1).
NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", "1"))
# HTTP "200 OK": the status every server response is expected to carry.
STATUS_OK = 200
def _check_ok(response, action):
    """Raise ``RuntimeError`` unless *response* carries the ``STATUS_OK`` status code.

    Used instead of ``assert`` so the check still runs under ``python -O``.

    Args:
        response: HTTP response object exposing ``status_code`` and ``text``.
        action: short description of the request, included in the error message.

    Raises:
        RuntimeError: if ``response.status_code != STATUS_OK``.
    """
    if response.status_code != STATUS_OK:
        raise RuntimeError(
            f"{action} failed with HTTP {response.status_code}: {response.text[:200]}"
        )


def _download_client_files():
    """Fetch the client artifacts from ``{URL}/get_client`` and save them as ./client.zip."""
    zip_response = requests.get(f"{URL}/get_client")
    _check_ok(zip_response, "GET /get_client")
    Path("./client.zip").write_bytes(zip_response.content)


def _upload_evaluation_keys(client):
    """Send the client's serialized evaluation keys to ``{URL}/add_key``.

    Args:
        client: the ``FHEModelClient`` whose evaluation keys are uploaded.

    Returns:
        The ``uid`` field of the server's JSON reply, used to tag later compute requests.
    """
    # The client first needs to create the private and evaluation keys.
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Evaluation keys can be large but only have to be shared once with the server.
    # len() is the payload size; sys.getsizeof would also count the bytes object header.
    print(f"Evaluation keys size: {len(serialized_evaluation_keys) / 1024 / 1024:.2f} MB")

    # Send the keys as a multipart file upload (requests `files=`) instead of a base64 field.
    response = requests.post(
        f"{URL}/add_key",
        files={"key": io.BytesIO(initial_bytes=serialized_evaluation_keys)},
    )
    _check_ok(response, "POST /add_key")
    return response.json()["uid"]


def main():
    """Run an encrypted-inference round trip against the server at ``URL``.

    Steps:
    1. Download client.zip.
    2. Upload the evaluation keys.
    3. Quantize and encrypt up to ``NUM_SAMPLES`` inputs, one query per input.
    4. Post the queries concurrently with grequests.
    5. Decrypt and de-quantize the results, then print the predictions.

    Raises:
        RuntimeError: if any server response has a status other than ``STATUS_OK``.
        ValueError: if a compute request produced no response (``grequests.map`` gave None).
    """
    # TODO(review): placeholder. The code that loads the samples is not present here.
    # Replace ``...`` with an indexable tensor of inputs before running; slices of it
    # are indexed as ``X[[i], :]`` and converted with ``.numpy()`` below.
    train_sub_set = ...

    # Get the necessary data for the client (client.zip)
    _download_client_files()

    # Get the data to infer: at most NUM_SAMPLES inputs (1 by default).
    X = train_sub_set[:NUM_SAMPLES]

    # Create the client from the downloaded artifacts; keys live under ./keys.
    client = FHEModelClient(path_dir="./", key_dir="./keys")
    uid = _upload_evaluation_keys(client)

    # Launch the queries: one encrypted request per input, not sent until grequests.map.
    inferences = []
    for index in range(len(X)):
        # [[index], :] keeps the leading batch dimension (a batch of one).
        clear_input = X[[index], :].numpy()
        print("Input shape:", clear_input.shape)
        assert isinstance(clear_input, numpy.ndarray)

        print("Quantize/Encrypt")
        encrypted_input = client.quantize_encrypt_serialize(clear_input)
        assert isinstance(encrypted_input, bytes)
        print(f"Encrypted input size: {len(encrypted_input) / 1024 / 1024:.2f} MB")

        print("Posting query")
        inferences.append(
            grequests.post(
                f"{URL}/compute",
                files={"model_input": io.BytesIO(encrypted_input)},
                data={"uid": uid},
            )
        )
    print("Posted!")

    # Send all queries concurrently, then unpack the results in submission order.
    decrypted_predictions = []
    for result in grequests.map(inferences):
        if result is None:
            raise ValueError(
                "Result is None, probably because the server crashed due to lack of available memory."
            )
        _check_ok(result, "POST /compute")
        print("OK!")
        # Each query holds a single input, so keep only the first decrypted row.
        decrypted_predictions.append(client.deserialize_decrypt_dequantize(result.content)[0])
    print(decrypted_predictions)


if __name__ == "__main__":
    main()