Yi Kong | 8328301 | 2023-12-13 12:57:00 +0900 | [diff] [blame^] | 1 | //===- InteractiveModelRunner.h ---- "gym" ML model runner -----*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | |
| 10 | #ifndef LLVM_ANALYSIS_INTERACTIVEMODELRUNNER_H |
| 11 | #define LLVM_ANALYSIS_INTERACTIVEMODELRUNNER_H |
| 12 | |
| 13 | #include "llvm/Analysis/MLModelRunner.h" |
| 14 | #include "llvm/Analysis/TensorSpec.h" |
| 15 | #include "llvm/Analysis/Utils/TrainingLogger.h" |
| 16 | #include "llvm/Config/llvm-config.h" |
| 17 | #include "llvm/Support/FileSystem.h" |
| 18 | #include "llvm/Support/raw_ostream.h" |
| 19 | #include <system_error> |
| 20 | |
| 21 | namespace llvm { |
| 22 | |
| 23 | /// A MLModelRunner that asks for advice from an external agent, or host. It |
| 24 | /// uses 2 files - ideally named pipes - one to send data to that agent, and |
| 25 | /// one to receive advice. |
| 26 | /// The data exchange uses the training logger (Utils/TrainingLogger.h) format. |
| 27 | /// Specifically, the compiler will send the log header, set the context, and |
| 28 | /// send observations; the host is expected to reply with a tensor value after |
/// each observation, as a binary buffer conforming to the shape of the advice.
/// Interleaved, the exchanged data closely resembles a training log in which
/// the reward signal is not captured.
| 32 | /// |
/// Note that the correctness of the received data is the responsibility of the
/// host. In particular, if the host sends insufficient data, the compiler will
/// block while waiting for advice.
| 36 | /// |
/// Note that, to avoid deadlock, the host can either open the pipes RW, or
/// first open the pipe to the compiler - i.e. the "Inbound" - and then the
/// "Outbound". This is because the compiler first tries to open the inbound
/// pipe, which will hang until there's a writer on the other end.
| 41 | class InteractiveModelRunner : public MLModelRunner { |
| 42 | public: |
| 43 | InteractiveModelRunner(LLVMContext &Ctx, |
| 44 | const std::vector<TensorSpec> &Inputs, |
| 45 | const TensorSpec &Advice, StringRef OutboundName, |
| 46 | StringRef InboundName); |
| 47 | |
| 48 | static bool classof(const MLModelRunner *R) { |
| 49 | return R->getKind() == MLModelRunner::Kind::Interactive; |
| 50 | } |
| 51 | void switchContext(StringRef Name) override { |
| 52 | Log->switchContext(Name); |
| 53 | Log->flush(); |
| 54 | } |
| 55 | |
| 56 | virtual ~InteractiveModelRunner(); |
| 57 | |
| 58 | private: |
| 59 | void *evaluateUntyped() override; |
| 60 | // This must be declared before InEC if we want to initialize it in the |
| 61 | // ctor initializer list. |
| 62 | int Inbound = -1; |
| 63 | const std::vector<TensorSpec> InputSpecs; |
| 64 | const TensorSpec OutputSpec; |
| 65 | std::error_code OutEC; |
| 66 | std::error_code InEC; |
| 67 | std::vector<char> OutputBuffer; |
| 68 | std::unique_ptr<Logger> Log; |
| 69 | }; |
| 70 | } // namespace llvm |
| 71 | #endif // LLVM_ANALYSIS_INTERACTIVEMODELRUNNER_H |